diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d7e8c60be5c..9dff2d46700 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -182,3 +182,24 @@ jobs: - name: Run ipfs-kad example run: RUST_LOG=libp2p_swarm=debug,libp2p_kad=trace,libp2p_tcp=debug cargo run --example ipfs-kad + + rustfmt: + runs-on: ubuntu-latest + steps: + + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@0.9.0 + with: + access_token: ${{ github.token }} + + - uses: actions/checkout@v2.3.4 + + - uses: actions-rs/toolchain@v1.0.7 + with: + profile: minimal + toolchain: stable + override: true + components: rustfmt + + - name: Check formatting + run: cargo fmt -- --check diff --git a/core/benches/peer_id.rs b/core/benches/peer_id.rs index 5dfb0d7c132..9a6935113ec 100644 --- a/core/benches/peer_id.rs +++ b/core/benches/peer_id.rs @@ -35,9 +35,7 @@ fn from_bytes(c: &mut Criterion) { } fn clone(c: &mut Criterion) { - let peer_id = identity::Keypair::generate_ed25519() - .public() - .to_peer_id(); + let peer_id = identity::Keypair::generate_ed25519().public().to_peer_id(); c.bench_function("clone", |b| { b.iter(|| { @@ -48,11 +46,7 @@ fn clone(c: &mut Criterion) { fn sort_vec(c: &mut Criterion) { let peer_ids: Vec<_> = (0..100) - .map(|_| { - identity::Keypair::generate_ed25519() - .public() - .to_peer_id() - }) + .map(|_| identity::Keypair::generate_ed25519().public().to_peer_id()) .collect(); c.bench_function("sort_vec", |b| { diff --git a/core/build.rs b/core/build.rs index c08517dee58..9692abd9c81 100644 --- a/core/build.rs +++ b/core/build.rs @@ -19,5 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/keys.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/keys.proto"], &["src"]).unwrap(); } diff --git a/core/src/connection.rs b/core/src/connection.rs index 50b44b86ccd..335e2046c2d 100644 --- a/core/src/connection.rs +++ b/core/src/connection.rs @@ -28,16 +28,16 @@ pub(crate) mod pool; pub use error::{ConnectionError, PendingConnectionError}; pub use handler::{ConnectionHandler, ConnectionHandlerEvent, IntoConnectionHandler}; -pub use listeners::{ListenerId, ListenersStream, ListenersEvent}; +pub use listeners::{ListenerId, ListenersEvent, ListenersStream}; pub use manager::ConnectionId; -pub use substream::{Substream, SubstreamEndpoint, Close}; +pub use pool::{ConnectionCounters, ConnectionLimits}; pub use pool::{EstablishedConnection, EstablishedConnectionIter, PendingConnection}; -pub use pool::{ConnectionLimits, ConnectionCounters}; +pub use substream::{Close, Substream, SubstreamEndpoint}; use crate::muxing::StreamMuxer; use crate::{Multiaddr, PeerId}; -use std::{error::Error, fmt, pin::Pin, task::Context, task::Poll}; use std::hash::Hash; +use std::{error::Error, fmt, pin::Pin, task::Context, task::Poll}; use substream::{Muxing, SubstreamEvent}; /// The endpoint roles associated with a peer-to-peer communication channel. @@ -55,7 +55,7 @@ impl std::ops::Not for Endpoint { fn not(self) -> Self::Output { match self { Endpoint::Dialer => Endpoint::Listener, - Endpoint::Listener => Endpoint::Dialer + Endpoint::Listener => Endpoint::Dialer, } } } @@ -86,7 +86,7 @@ pub enum ConnectedPoint { local_addr: Multiaddr, /// Stack of protocols used to send back data to the remote. send_back_addr: Multiaddr, - } + }, } impl From<&'_ ConnectedPoint> for Endpoint { @@ -106,7 +106,7 @@ impl ConnectedPoint { pub fn to_endpoint(&self) -> Endpoint { match self { ConnectedPoint::Dialer { .. } => Endpoint::Dialer, - ConnectedPoint::Listener { .. } => Endpoint::Listener + ConnectedPoint::Listener { .. } => Endpoint::Listener, } } @@ -114,7 +114,7 @@ impl ConnectedPoint { pub fn is_dialer(&self) -> bool { match self { ConnectedPoint::Dialer { .. } => true, - ConnectedPoint::Listener { .. } => false + ConnectedPoint::Listener { .. } => false, } } @@ -122,7 +122,7 @@ impl ConnectedPoint { pub fn is_listener(&self) -> bool { match self { ConnectedPoint::Dialer { .. } => false, - ConnectedPoint::Listener { .. } => true + ConnectedPoint::Listener { .. } => true, } } @@ -237,9 +237,10 @@ where /// Polls the connection for events produced by the associated handler /// as a result of I/O activity on the substream multiplexer. - pub fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll, ConnectionError>> - { + pub fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, ConnectionError>> { loop { let mut io_pending = false; @@ -247,10 +248,13 @@ where // of new substreams. match self.muxing.poll(cx) { Poll::Pending => io_pending = true, - Poll::Ready(Ok(SubstreamEvent::InboundSubstream { substream })) => { - self.handler.inject_substream(substream, SubstreamEndpoint::Listener) - } - Poll::Ready(Ok(SubstreamEvent::OutboundSubstream { user_data, substream })) => { + Poll::Ready(Ok(SubstreamEvent::InboundSubstream { substream })) => self + .handler + .inject_substream(substream, SubstreamEndpoint::Listener), + Poll::Ready(Ok(SubstreamEvent::OutboundSubstream { + user_data, + substream, + })) => { let endpoint = SubstreamEndpoint::Dialer(user_data); self.handler.inject_substream(substream, endpoint) } @@ -265,7 +269,7 @@ where match self.handler.poll(cx) { Poll::Pending => { if io_pending { - return Poll::Pending // Nothing to do + return Poll::Pending; // Nothing to do } } Poll::Ready(Ok(ConnectionHandlerEvent::OutboundSubstreamRequest(user_data))) => { @@ -310,7 +314,7 @@ impl<'a> OutgoingInfo<'a> { /// Builds a `ConnectedPoint` corresponding to the outgoing connection. pub fn to_connected_point(&self) -> ConnectedPoint { ConnectedPoint::Dialer { - address: self.address.clone() + address: self.address.clone(), } } } diff --git a/core/src/connection/error.rs b/core/src/connection/error.rs index 1836965e43e..66da0670c98 100644 --- a/core/src/connection/error.rs +++ b/core/src/connection/error.rs @@ -20,7 +20,7 @@ use crate::connection::ConnectionLimit; use crate::transport::TransportError; -use std::{io, fmt}; +use std::{fmt, io}; /// Errors that can occur in the context of an established `Connection`. #[derive(Debug)] @@ -33,23 +33,19 @@ pub enum ConnectionError { Handler(THandlerErr), } -impl fmt::Display -for ConnectionError +impl fmt::Display for ConnectionError where THandlerErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ConnectionError::IO(err) => - write!(f, "Connection error: I/O error: {}", err), - ConnectionError::Handler(err) => - write!(f, "Connection error: Handler error: {}", err), + ConnectionError::IO(err) => write!(f, "Connection error: I/O error: {}", err), + ConnectionError::Handler(err) => write!(f, "Connection error: Handler error: {}", err), } } } -impl std::error::Error -for ConnectionError +impl std::error::Error for ConnectionError where THandlerErr: std::error::Error + 'static, { @@ -80,29 +76,29 @@ pub enum PendingConnectionError { IO(io::Error), } -impl fmt::Display -for PendingConnectionError +impl fmt::Display for PendingConnectionError where TTransErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - PendingConnectionError::IO(err) => - write!(f, "Pending connection: I/O error: {}", err), - PendingConnectionError::Transport(err) => - write!(f, "Pending connection: Transport error: {}", err), - PendingConnectionError::InvalidPeerId => - write!(f, "Pending connection: Invalid peer ID."), - PendingConnectionError::ConnectionLimit(l) => - write!(f, "Connection error: Connection limit: {}.", l), + PendingConnectionError::IO(err) => write!(f, "Pending connection: I/O error: {}", err), + PendingConnectionError::Transport(err) => { + write!(f, "Pending connection: Transport error: {}", err) + } + PendingConnectionError::InvalidPeerId => { + write!(f, "Pending connection: Invalid peer ID.") + } + PendingConnectionError::ConnectionLimit(l) => { + write!(f, "Connection error: Connection limit: {}.", l) + } } } } -impl std::error::Error -for PendingConnectionError +impl std::error::Error for PendingConnectionError where - TTransErr: std::error::Error + 'static + TTransErr: std::error::Error + 'static, { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { diff --git a/core/src/connection/handler.rs b/core/src/connection/handler.rs index 0f1c2f6bcd8..011dcc2b61e 100644 --- a/core/src/connection/handler.rs +++ b/core/src/connection/handler.rs @@ -18,9 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use super::{Connected, SubstreamEndpoint}; use crate::Multiaddr; use std::{fmt::Debug, task::Context, task::Poll}; -use super::{Connected, SubstreamEndpoint}; /// The interface of a connection handler. /// @@ -53,7 +53,11 @@ pub trait ConnectionHandler { /// Implementations are allowed to panic in the case of dialing if the `user_data` in /// `endpoint` doesn't correspond to what was returned earlier when polling, or is used /// multiple times. - fn inject_substream(&mut self, substream: Self::Substream, endpoint: SubstreamEndpoint); + fn inject_substream( + &mut self, + substream: Self::Substream, + endpoint: SubstreamEndpoint, + ); /// Notifies the handler of an event. fn inject_event(&mut self, event: Self::InEvent); @@ -64,8 +68,10 @@ pub trait ConnectionHandler { /// Polls the handler for events. /// /// Returning an error will close the connection to the remote. - fn poll(&mut self, cx: &mut Context<'_>) - -> Poll, Self::Error>>; + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>; } /// Prototype for a `ConnectionHandler`. @@ -82,7 +88,7 @@ pub trait IntoConnectionHandler { impl IntoConnectionHandler for T where - T: ConnectionHandler + T: ConnectionHandler, { type Handler = Self; @@ -91,9 +97,12 @@ where } } -pub(crate) type THandlerInEvent = <::Handler as ConnectionHandler>::InEvent; -pub(crate) type THandlerOutEvent = <::Handler as ConnectionHandler>::OutEvent; -pub(crate) type THandlerError = <::Handler as ConnectionHandler>::Error; +pub(crate) type THandlerInEvent = + <::Handler as ConnectionHandler>::InEvent; +pub(crate) type THandlerOutEvent = + <::Handler as ConnectionHandler>::OutEvent; +pub(crate) type THandlerError = + <::Handler as ConnectionHandler>::Error; /// Event produced by a handler. #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -109,24 +118,26 @@ pub enum ConnectionHandlerEvent { impl ConnectionHandlerEvent { /// If this is `OutboundSubstreamRequest`, maps the content to something else. pub fn map_outbound_open_info(self, map: F) -> ConnectionHandlerEvent - where F: FnOnce(TOutboundOpenInfo) -> I + where + F: FnOnce(TOutboundOpenInfo) -> I, { match self { ConnectionHandlerEvent::OutboundSubstreamRequest(val) => { ConnectionHandlerEvent::OutboundSubstreamRequest(map(val)) - }, + } ConnectionHandlerEvent::Custom(val) => ConnectionHandlerEvent::Custom(val), } } /// If this is `Custom`, maps the content to something else. pub fn map_custom(self, map: F) -> ConnectionHandlerEvent - where F: FnOnce(TCustom) -> I + where + F: FnOnce(TCustom) -> I, { match self { ConnectionHandlerEvent::OutboundSubstreamRequest(val) => { ConnectionHandlerEvent::OutboundSubstreamRequest(val) - }, + } ConnectionHandlerEvent::Custom(val) => ConnectionHandlerEvent::Custom(map(val)), } } diff --git a/core/src/connection/listeners.rs b/core/src/connection/listeners.rs index 02982d87393..cf6daa17f5f 100644 --- a/core/src/connection/listeners.rs +++ b/core/src/connection/listeners.rs @@ -20,7 +20,10 @@ //! Manage listening on multiple multiaddresses at once. -use crate::{Multiaddr, Transport, transport::{TransportError, ListenerEvent}}; +use crate::{ + transport::{ListenerEvent, TransportError}, + Multiaddr, Transport, +}; use futures::{prelude::*, task::Context, task::Poll}; use log::debug; use smallvec::SmallVec; @@ -86,7 +89,7 @@ where /// can be resized, the only way is to use a `Pin>`. listeners: VecDeque>>>, /// The next listener ID to assign. - next_id: ListenerId + next_id: ListenerId, } /// The ID of a single listener. @@ -109,7 +112,7 @@ where #[pin] listener: TTrans::Listener, /// Addresses it is listening on. - addresses: SmallVec<[Multiaddr; 4]> + addresses: SmallVec<[Multiaddr; 4]>, } /// Event that can happen on the `ListenersStream`. @@ -122,14 +125,14 @@ where /// The listener that is listening on the new address. listener_id: ListenerId, /// The new address that is being listened on. - listen_addr: Multiaddr + listen_addr: Multiaddr, }, /// An address is no longer being listened on. AddressExpired { /// The listener that is no longer listening on the address. listener_id: ListenerId, /// The new address that is being listened on. - listen_addr: Multiaddr + listen_addr: Multiaddr, }, /// A connection is incoming on one of the listeners. Incoming { @@ -161,7 +164,7 @@ where listener_id: ListenerId, /// The error value. error: TTrans::Error, - } + }, } impl ListenersStream @@ -173,7 +176,7 @@ where ListenersStream { transport, listeners: VecDeque::new(), - next_id: ListenerId(1) + next_id: ListenerId(1), } } @@ -183,14 +186,17 @@ where ListenersStream { transport, listeners: VecDeque::with_capacity(capacity), - next_id: ListenerId(1) + next_id: ListenerId(1), } } /// Start listening on a multiaddress. /// /// Returns an error if the transport doesn't support the given multiaddress. - pub fn listen_on(&mut self, addr: Multiaddr) -> Result> + pub fn listen_on( + &mut self, + addr: Multiaddr, + ) -> Result> where TTrans: Clone, { @@ -198,7 +204,7 @@ where self.listeners.push_back(Box::pin(Listener { id: self.next_id, listener, - addresses: SmallVec::new() + addresses: SmallVec::new(), })); let id = self.next_id; self.next_id = ListenerId(self.next_id.0 + 1); @@ -237,17 +243,23 @@ where Poll::Pending => { self.listeners.push_front(listener); remaining -= 1; - if remaining == 0 { break } + if remaining == 0 { + break; + } } - Poll::Ready(Some(Ok(ListenerEvent::Upgrade { upgrade, local_addr, remote_addr }))) => { + Poll::Ready(Some(Ok(ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + }))) => { let id = *listener_project.id; self.listeners.push_front(listener); return Poll::Ready(ListenersEvent::Incoming { listener_id: id, upgrade, local_addr, - send_back_addr: remote_addr - }) + send_back_addr: remote_addr, + }); } Poll::Ready(Some(Ok(ListenerEvent::NewAddress(a)))) => { if listener_project.addresses.contains(&a) { @@ -260,8 +272,8 @@ where self.listeners.push_front(listener); return Poll::Ready(ListenersEvent::NewAddress { listener_id: id, - listen_addr: a - }) + listen_addr: a, + }); } Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(a)))) => { listener_project.addresses.retain(|x| x != &a); @@ -269,8 +281,8 @@ where self.listeners.push_front(listener); return Poll::Ready(ListenersEvent::AddressExpired { listener_id: id, - listen_addr: a - }) + listen_addr: a, + }); } Poll::Ready(Some(Ok(ListenerEvent::Error(error)))) => { let id = *listener_project.id; @@ -278,7 +290,7 @@ where return Poll::Ready(ListenersEvent::Error { listener_id: id, error, - }) + }); } Poll::Ready(None) => { return Poll::Ready(ListenersEvent::Closed { @@ -313,11 +325,7 @@ where } } -impl Unpin for ListenersStream -where - TTrans: Transport, -{ -} +impl Unpin for ListenersStream where TTrans: Transport {} impl fmt::Debug for ListenersStream where @@ -338,22 +346,36 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - ListenersEvent::NewAddress { listener_id, listen_addr } => f + ListenersEvent::NewAddress { + listener_id, + listen_addr, + } => f .debug_struct("ListenersEvent::NewAddress") .field("listener_id", listener_id) .field("listen_addr", listen_addr) .finish(), - ListenersEvent::AddressExpired { listener_id, listen_addr } => f + ListenersEvent::AddressExpired { + listener_id, + listen_addr, + } => f .debug_struct("ListenersEvent::AddressExpired") .field("listener_id", listener_id) .field("listen_addr", listen_addr) .finish(), - ListenersEvent::Incoming { listener_id, local_addr, .. } => f + ListenersEvent::Incoming { + listener_id, + local_addr, + .. + } => f .debug_struct("ListenersEvent::Incoming") .field("listener_id", listener_id) .field("local_addr", local_addr) .finish(), - ListenersEvent::Closed { listener_id, addresses, reason } => f + ListenersEvent::Closed { + listener_id, + addresses, + reason, + } => f .debug_struct("ListenersEvent::Closed") .field("listener_id", listener_id) .field("addresses", addresses) @@ -363,13 +385,15 @@ where .debug_struct("ListenersEvent::Error") .field("listener_id", listener_id) .field("error", error) - .finish() + .finish(), } } } #[cfg(test)] mod tests { + use futures::{future::BoxFuture, stream::BoxStream}; + use super::*; use crate::transport; @@ -396,11 +420,15 @@ mod tests { }); match listeners.next().await.unwrap() { - ListenersEvent::Incoming { local_addr, send_back_addr, .. } => { + ListenersEvent::Incoming { + local_addr, + send_back_addr, + .. + } => { assert_eq!(local_addr, address); assert!(send_back_addr != address); - }, - _ => panic!() + } + _ => panic!(), } }); } @@ -415,21 +443,37 @@ mod tests { impl transport::Transport for DummyTrans { type Output = (); type Error = std::io::Error; - type Listener = Pin, std::io::Error>>>>; - type ListenerUpgrade = Pin>>>; - type Dial = Pin>>>; - - fn listen_on(self, _: Multiaddr) -> Result> { + type Listener = BoxStream< + 'static, + Result, std::io::Error>, + >; + type ListenerUpgrade = BoxFuture<'static, Result>; + type Dial = BoxFuture<'static, Result>; + + fn listen_on( + self, + _: Multiaddr, + ) -> Result> { Ok(Box::pin(stream::unfold((), |()| async move { - Some((Ok(ListenerEvent::Error(std::io::Error::from(std::io::ErrorKind::Other))), ())) + Some(( + Ok(ListenerEvent::Error(std::io::Error::from( + std::io::ErrorKind::Other, + ))), + (), + )) }))) } - fn dial(self, _: Multiaddr) -> Result> { + fn dial( + self, + _: Multiaddr, + ) -> Result> { panic!() } - fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { None } + fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { + None + } } async_std::task::block_on(async move { @@ -439,8 +483,8 @@ mod tests { for _ in 0..10 { match listeners.next().await.unwrap() { - ListenersEvent::Error { .. } => {}, - _ => panic!() + ListenersEvent::Error { .. } => {} + _ => panic!(), } } }); @@ -455,21 +499,32 @@ mod tests { impl transport::Transport for DummyTrans { type Output = (); type Error = std::io::Error; - type Listener = Pin, std::io::Error>>>>; - type ListenerUpgrade = Pin>>>; - type Dial = Pin>>>; - - fn listen_on(self, _: Multiaddr) -> Result> { + type Listener = BoxStream< + 'static, + Result, std::io::Error>, + >; + type ListenerUpgrade = BoxFuture<'static, Result>; + type Dial = BoxFuture<'static, Result>; + + fn listen_on( + self, + _: Multiaddr, + ) -> Result> { Ok(Box::pin(stream::unfold((), |()| async move { Some((Err(std::io::Error::from(std::io::ErrorKind::Other)), ())) }))) } - fn dial(self, _: Multiaddr) -> Result> { + fn dial( + self, + _: Multiaddr, + ) -> Result> { panic!() } - fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { None } + fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { + None + } } async_std::task::block_on(async move { @@ -478,8 +533,8 @@ mod tests { listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); match listeners.next().await.unwrap() { - ListenersEvent::Closed { .. } => {}, - _ => panic!() + ListenersEvent::Closed { .. } => {} + _ => panic!(), } }); } diff --git a/core/src/connection/manager.rs b/core/src/connection/manager.rs index b450f0d602f..1d7acb92e69 100644 --- a/core/src/connection/manager.rs +++ b/core/src/connection/manager.rs @@ -18,39 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{ - Executor, - muxing::StreamMuxer, +use super::{ + handler::{THandlerError, THandlerInEvent, THandlerOutEvent}, + Connected, ConnectedPoint, Connection, ConnectionError, ConnectionHandler, + IntoConnectionHandler, PendingConnectionError, Substream, }; +use crate::{muxing::StreamMuxer, Executor}; use fnv::FnvHashMap; -use futures::{ - prelude::*, - channel::mpsc, - stream::FuturesUnordered -}; +use futures::{channel::mpsc, prelude::*, stream::FuturesUnordered}; use std::{ collections::hash_map, - error, - fmt, - mem, + error, fmt, mem, pin::Pin, task::{Context, Poll}, }; -use super::{ - Connected, - ConnectedPoint, - Connection, - ConnectionError, - ConnectionHandler, - IntoConnectionHandler, - PendingConnectionError, - Substream, - handler::{ - THandlerInEvent, - THandlerOutEvent, - THandlerError, - }, -}; use task::{Task, TaskId}; mod task; @@ -123,11 +104,10 @@ pub struct Manager { events_tx: mpsc::Sender>, /// Receiver for events reported from managed tasks. - events_rx: mpsc::Receiver> + events_rx: mpsc::Receiver>, } -impl fmt::Debug for Manager -{ +impl fmt::Debug for Manager { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_map() .entries(self.tasks.iter().map(|(id, task)| (id, &task.state))) @@ -196,7 +176,7 @@ pub enum Event<'a, H: IntoConnectionHandler, TE> { /// What happened. error: PendingConnectionError, /// The handler that was supposed to handle the failed connection. - handler: H + handler: H, }, /// An established connection has been closed. @@ -225,7 +205,7 @@ pub enum Event<'a, H: IntoConnectionHandler, TE> { /// The entry associated with the connection that produced the event. entry: EstablishedEntry<'a, THandlerInEvent>, /// The produced event. - event: THandlerOutEvent + event: THandlerOutEvent, }, /// A connection to a node has changed its address. @@ -250,7 +230,7 @@ impl Manager { executor: config.executor, local_spawns: FuturesUnordered::new(), events_tx: tx, - events_rx: rx + events_rx: rx, } } @@ -265,18 +245,28 @@ impl Manager { M::OutboundSubstream: Send + 'static, F: Future> + Send + 'static, H: IntoConnectionHandler + Send + 'static, - H::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + H::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, { let task_id = self.next_task_id; self.next_task_id.0 += 1; let (tx, rx) = mpsc::channel(self.task_command_buffer_size); - self.tasks.insert(task_id, TaskInfo { sender: tx, state: TaskState::Pending }); - - let task = Box::pin(Task::pending(task_id, self.events_tx.clone(), rx, future, handler)); + self.tasks.insert( + task_id, + TaskInfo { + sender: tx, + state: TaskState::Pending, + }, + ); + + let task = Box::pin(Task::pending( + task_id, + self.events_tx.clone(), + rx, + future, + handler, + )); if let Some(executor) = &mut self.executor { executor.exec(task); } else { @@ -290,9 +280,7 @@ impl Manager { pub fn add(&mut self, conn: Connection, info: Connected) -> ConnectionId where H: IntoConnectionHandler + Send + 'static, - H::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + H::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TE: error::Error + Send + 'static, M: StreamMuxer + Send + Sync + 'static, @@ -302,9 +290,13 @@ impl Manager { self.next_task_id.0 += 1; let (tx, rx) = mpsc::channel(self.task_command_buffer_size); - self.tasks.insert(task_id, TaskInfo { - sender: tx, state: TaskState::Established(info) - }); + self.tasks.insert( + task_id, + TaskInfo { + sender: tx, + state: TaskState::Established(info), + }, + ); let task: Pin>>, _, _, _>>> = Box::pin(Task::established(task_id, self.events_tx.clone(), rx, conn)); @@ -329,7 +321,13 @@ impl Manager { /// Checks whether an established connection with the given ID is currently managed. pub fn is_established(&self, id: &ConnectionId) -> bool { - matches!(self.tasks.get(&id.0), Some(TaskInfo { state: TaskState::Established(..), .. })) + matches!( + self.tasks.get(&id.0), + Some(TaskInfo { + state: TaskState::Established(..), + .. + }) + ) } /// Polls the manager for events relating to the managed connections. @@ -341,8 +339,9 @@ impl Manager { let event = loop { match self.events_rx.poll_next_unpin(cx) { Poll::Ready(Some(event)) => { - if self.tasks.contains_key(event.id()) { // (1) - break event + if self.tasks.contains_key(event.id()) { + // (1) + break event; } } Poll::Pending => return Poll::Pending, @@ -352,12 +351,12 @@ impl Manager { if let hash_map::Entry::Occupied(mut task) = self.tasks.entry(*event.id()) { Poll::Ready(match event { - task::Event::Notify { id: _, event } => - Event::ConnectionEvent { - entry: EstablishedEntry { task }, - event - }, - task::Event::Established { id: _, info } => { // (2) + task::Event::Notify { id: _, event } => Event::ConnectionEvent { + entry: EstablishedEntry { task }, + event, + }, + task::Event::Established { id: _, info } => { + // (2) task.get_mut().state = TaskState::Established(info); // (3) Event::ConnectionEstablished { entry: EstablishedEntry { task }, @@ -389,11 +388,14 @@ impl Manager { let id = ConnectionId(id); let task = task.remove(); match task.state { - TaskState::Established(connected) => - Event::ConnectionClosed { id, connected, error }, + TaskState::Established(connected) => Event::ConnectionClosed { + id, + connected, + error, + }, TaskState::Pending => unreachable!( "`Event::Closed` implies (2) occurred on that task and thus (3)." - ), + ), } } }) @@ -407,14 +409,14 @@ impl Manager { #[derive(Debug)] pub enum Entry<'a, I> { Pending(PendingEntry<'a, I>), - Established(EstablishedEntry<'a, I>) + Established(EstablishedEntry<'a, I>), } impl<'a, I> Entry<'a, I> { fn new(task: hash_map::OccupiedEntry<'a, TaskId, TaskInfo>) -> Self { match &task.get().state { TaskState::Pending => Entry::Pending(PendingEntry { task }), - TaskState::Established(_) => Entry::Established(EstablishedEntry { task }) + TaskState::Established(_) => Entry::Established(EstablishedEntry { task }), } } } @@ -442,10 +444,13 @@ impl<'a, I> EstablishedEntry<'a, I> { /// > the connection handler not being ready at this time. pub fn notify_handler(&mut self, event: I) -> Result<(), I> { let cmd = task::Command::NotifyHandler(event); // (*) - self.task.get_mut().sender.try_send(cmd) + self.task + .get_mut() + .sender + .try_send(cmd) .map_err(|e| match e.into_inner() { task::Command::NotifyHandler(event) => event, - _ => panic!("Unexpected command. Expected `NotifyHandler`") // see (*) + _ => panic!("Unexpected command. Expected `NotifyHandler`"), // see (*) }) } @@ -455,7 +460,7 @@ impl<'a, I> EstablishedEntry<'a, I> { /// /// Returns `Err(())` if the background task associated with the connection /// is terminating and the connection is about to close. - pub fn poll_ready_notify_handler(&mut self, cx: &mut Context<'_>) -> Poll> { + pub fn poll_ready_notify_handler(&mut self, cx: &mut Context<'_>) -> Poll> { self.task.get_mut().sender.poll_ready(cx).map_err(|_| ()) } @@ -469,9 +474,15 @@ impl<'a, I> EstablishedEntry<'a, I> { pub fn start_close(mut self) { // Clone the sender so that we are guaranteed to have // capacity for the close command (every sender gets a slot). - match self.task.get_mut().sender.clone().try_send(task::Command::Close) { - Ok(()) => {}, - Err(e) => assert!(e.is_disconnected(), "No capacity for close command.") + match self + .task + .get_mut() + .sender + .clone() + .try_send(task::Command::Close) + { + Ok(()) => {} + Err(e) => assert!(e.is_disconnected(), "No capacity for close command."), } } @@ -479,7 +490,7 @@ impl<'a, I> EstablishedEntry<'a, I> { pub fn connected(&self) -> &Connected { match &self.task.get().state { TaskState::Established(c) => c, - TaskState::Pending => unreachable!("By Entry::new()") + TaskState::Pending => unreachable!("By Entry::new()"), } } @@ -490,7 +501,7 @@ impl<'a, I> EstablishedEntry<'a, I> { pub fn remove(self) -> Connected { match self.task.remove().state { TaskState::Established(c) => c, - TaskState::Pending => unreachable!("By Entry::new()") + TaskState::Pending => unreachable!("By Entry::new()"), } } @@ -504,7 +515,7 @@ impl<'a, I> EstablishedEntry<'a, I> { /// (i.e. pending). #[derive(Debug)] pub struct PendingEntry<'a, I> { - task: hash_map::OccupiedEntry<'a, TaskId, TaskInfo> + task: hash_map::OccupiedEntry<'a, TaskId, TaskInfo>, } impl<'a, I> PendingEntry<'a, I> { @@ -514,7 +525,7 @@ impl<'a, I> PendingEntry<'a, I> { } /// Aborts the pending connection attempt. - pub fn abort(self) { + pub fn abort(self) { self.task.remove(); } } diff --git a/core/src/connection/manager/task.rs b/core/src/connection/manager/task.rs index a7bdbd3cbbd..db8fb43adb6 100644 --- a/core/src/connection/manager/task.rs +++ b/core/src/connection/manager/task.rs @@ -18,29 +18,19 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use super::ConnectResult; use crate::{ - Multiaddr, - muxing::StreamMuxer, connection::{ self, - Close, - Connected, - Connection, - ConnectionError, - ConnectionHandler, - IntoConnectionHandler, - PendingConnectionError, - Substream, - handler::{ - THandlerInEvent, - THandlerOutEvent, - THandlerError, - }, + handler::{THandlerError, THandlerInEvent, THandlerOutEvent}, + Close, Connected, Connection, ConnectionError, ConnectionHandler, IntoConnectionHandler, + PendingConnectionError, Substream, }, + muxing::StreamMuxer, + Multiaddr, }; -use futures::{prelude::*, channel::mpsc, stream}; +use futures::{channel::mpsc, prelude::*, stream}; use std::{pin::Pin, task::Context, task::Poll}; -use super::ConnectResult; /// Identifier of a [`Task`] in a [`Manager`](super::Manager). #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] @@ -62,16 +52,26 @@ pub enum Event { /// A connection to a node has succeeded. Established { id: TaskId, info: Connected }, /// A pending connection failed. - Failed { id: TaskId, error: PendingConnectionError, handler: H }, + Failed { + id: TaskId, + error: PendingConnectionError, + handler: H, + }, /// A node we are connected to has changed its address. AddressChange { id: TaskId, new_address: Multiaddr }, /// Notify the manager of an event from the connection. - Notify { id: TaskId, event: THandlerOutEvent }, + Notify { + id: TaskId, + event: THandlerOutEvent, + }, /// A connection closed, possibly due to an error. /// /// If `error` is `None`, the connection has completed /// an active orderly close. - Closed { id: TaskId, error: Option>> } + Closed { + id: TaskId, + error: Option>>, + }, } impl Event { @@ -91,7 +91,7 @@ pub struct Task where M: StreamMuxer, H: IntoConnectionHandler, - H::Handler: ConnectionHandler> + H::Handler: ConnectionHandler>, { /// The ID of this task. id: TaskId, @@ -110,7 +110,7 @@ impl Task where M: StreamMuxer, H: IntoConnectionHandler, - H::Handler: ConnectionHandler> + H::Handler: ConnectionHandler>, { /// Create a new task to connect and handle some node. pub fn pending( @@ -118,7 +118,7 @@ where events: mpsc::Sender>, commands: mpsc::Receiver>>, future: F, - handler: H + handler: H, ) -> Self { Task { id, @@ -136,13 +136,16 @@ where id: TaskId, events: mpsc::Sender>, commands: mpsc::Receiver>>, - connection: Connection + connection: Connection, ) -> Self { Task { id, events, commands: commands.fuse(), - state: State::Established { connection, event: None }, + state: State::Established { + connection, + event: None, + }, } } } @@ -152,7 +155,7 @@ enum State where M: StreamMuxer, H: IntoConnectionHandler, - H::Handler: ConnectionHandler> + H::Handler: ConnectionHandler>, { /// The connection is being negotiated. Pending { @@ -180,14 +183,14 @@ where Terminating(Event), /// The task has finished. - Done + Done, } impl Unpin for Task where M: StreamMuxer, H: IntoConnectionHandler, - H::Handler: ConnectionHandler> + H::Handler: ConnectionHandler>, { } @@ -196,9 +199,7 @@ where M: StreamMuxer, F: Future>, H: IntoConnectionHandler, - H::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + H::Handler: ConnectionHandler> + Send + 'static, { type Output = (); @@ -211,33 +212,33 @@ where 'poll: loop { match std::mem::replace(&mut this.state, State::Done) { - State::Pending { mut future, handler } => { + State::Pending { + mut future, + handler, + } => { // Check whether the task is still registered with a `Manager` // by polling the commands channel. match this.commands.poll_next_unpin(cx) { - Poll::Pending => {}, + Poll::Pending => {} Poll::Ready(None) => { // The manager has dropped the task; abort. - return Poll::Ready(()) + return Poll::Ready(()); + } + Poll::Ready(Some(_)) => { + panic!("Task received command while the connection is pending.") } - Poll::Ready(Some(_)) => panic!( - "Task received command while the connection is pending." - ) } // Check if the connection succeeded. match future.poll_unpin(cx) { Poll::Ready(Ok((info, muxer))) => { this.state = State::Established { - connection: Connection::new( - muxer, - handler.into_handler(&info), - ), - event: Some(Event::Established { id, info }) + connection: Connection::new(muxer, handler.into_handler(&info)), + event: Some(Event::Established { id, info }), } } Poll::Pending => { this.state = State::Pending { future, handler }; - return Poll::Pending + return Poll::Pending; } Poll::Ready(Err(error)) => { // Don't accept any further commands and terminate the @@ -249,23 +250,27 @@ where } } - State::Established { mut connection, event } => { + State::Established { + mut connection, + event, + } => { // Check for commands from the `Manager`. loop { match this.commands.poll_next_unpin(cx) { Poll::Pending => break, - Poll::Ready(Some(Command::NotifyHandler(event))) => - connection.inject_event(event), + Poll::Ready(Some(Command::NotifyHandler(event))) => { + connection.inject_event(event) + } Poll::Ready(Some(Command::Close)) => { // Don't accept any further commands. this.commands.get_mut().close(); // Discard the event, if any, and start a graceful close. this.state = State::Closing(connection.close()); - continue 'poll + continue 'poll; } Poll::Ready(None) => { // The manager has dropped the task or disappeared; abort. - return Poll::Ready(()) + return Poll::Ready(()); } } } @@ -274,44 +279,56 @@ where // Send the event to the manager. match this.events.poll_ready(cx) { Poll::Pending => { - this.state = State::Established { connection, event: Some(event) }; - return Poll::Pending + this.state = State::Established { + connection, + event: Some(event), + }; + return Poll::Pending; } Poll::Ready(result) => { if result.is_ok() { if let Ok(()) = this.events.start_send(event) { - this.state = State::Established { connection, event: None }; - continue 'poll + this.state = State::Established { + connection, + event: None, + }; + continue 'poll; } } // The manager is no longer reachable; abort. - return Poll::Ready(()) + return Poll::Ready(()); } } } else { // Poll the connection for new events. match Connection::poll(Pin::new(&mut connection), cx) { Poll::Pending => { - this.state = State::Established { connection, event: None }; - return Poll::Pending + this.state = State::Established { + connection, + event: None, + }; + return Poll::Pending; } Poll::Ready(Ok(connection::Event::Handler(event))) => { this.state = State::Established { connection, - event: Some(Event::Notify { id, event }) + event: Some(Event::Notify { id, event }), }; } Poll::Ready(Ok(connection::Event::AddressChange(new_address))) => { this.state = State::Established { connection, - event: Some(Event::AddressChange { id, new_address }) + event: Some(Event::AddressChange { id, new_address }), }; } Poll::Ready(Err(error)) => { // Don't accept any further commands. this.commands.get_mut().close(); // Terminate the task with the error, dropping the connection. - let event = Event::Closed { id, error: Some(error) }; + let event = Event::Closed { + id, + error: Some(error), + }; this.state = State::Terminating(event); } } @@ -322,19 +339,22 @@ where // Try to gracefully close the connection. match closing.poll_unpin(cx) { Poll::Ready(Ok(())) => { - let event = Event::Closed { id: this.id, error: None }; + let event = Event::Closed { + id: this.id, + error: None, + }; this.state = State::Terminating(event); } Poll::Ready(Err(e)) => { let event = Event::Closed { id: this.id, - error: Some(ConnectionError::IO(e)) + error: Some(ConnectionError::IO(e)), }; this.state = State::Terminating(event); } Poll::Pending => { this.state = State::Closing(closing); - return Poll::Pending + return Poll::Pending; } } } @@ -344,18 +364,18 @@ where match this.events.poll_ready(cx) { Poll::Pending => { self.state = State::Terminating(event); - return Poll::Pending + return Poll::Pending; } Poll::Ready(result) => { if result.is_ok() { let _ = this.events.start_send(event); } - return Poll::Ready(()) + return Poll::Ready(()); } } } - State::Done => panic!("`Task::poll()` called after completion.") + State::Done => panic!("`Task::poll()` called after completion."), } } } diff --git a/core/src/connection/pool.rs b/core/src/connection/pool.rs index 263c36a88a8..9925dd526c0 100644 --- a/core/src/connection/pool.rs +++ b/core/src/connection/pool.rs @@ -19,29 +19,15 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - ConnectedPoint, - PeerId, connection::{ self, - Connected, - Connection, - ConnectionId, - ConnectionLimit, - ConnectionError, - ConnectionHandler, - IncomingInfo, - IntoConnectionHandler, - OutgoingInfo, - Substream, - PendingConnectionError, - handler::{ - THandlerInEvent, - THandlerOutEvent, - THandlerError, - }, + handler::{THandlerError, THandlerInEvent, THandlerOutEvent}, manager::{self, Manager, ManagerConfig}, + Connected, Connection, ConnectionError, ConnectionHandler, ConnectionId, ConnectionLimit, + IncomingInfo, IntoConnectionHandler, OutgoingInfo, PendingConnectionError, Substream, }, muxing::StreamMuxer, + ConnectedPoint, PeerId, }; use either::Either; use fnv::FnvHashMap; @@ -76,9 +62,7 @@ pub struct Pool { disconnected: Vec, } -impl fmt::Debug -for Pool -{ +impl fmt::Debug for Pool { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { f.debug_struct("Pool") .field("counters", &self.counters) @@ -86,8 +70,7 @@ for Pool } } -impl Unpin -for Pool {} +impl Unpin for Pool {} /// Event that can happen on the `Pool`. pub enum PoolEvent<'a, THandler: IntoConnectionHandler, TTransErr> { @@ -157,56 +140,60 @@ pub enum PoolEvent<'a, THandler: IntoConnectionHandler, TTransErr> { }, } -impl<'a, THandler: IntoConnectionHandler, TTransErr> fmt::Debug for PoolEvent<'a, THandler, TTransErr> +impl<'a, THandler: IntoConnectionHandler, TTransErr> fmt::Debug + for PoolEvent<'a, THandler, TTransErr> where TTransErr: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match *self { - PoolEvent::ConnectionEstablished { ref connection, .. } => { - f.debug_tuple("PoolEvent::ConnectionEstablished") - .field(connection) - .finish() - }, - PoolEvent::ConnectionClosed { ref id, ref connected, ref error, .. } => { - f.debug_struct("PoolEvent::ConnectionClosed") - .field("id", id) - .field("connected", connected) - .field("error", error) - .finish() - }, - PoolEvent::PendingConnectionError { ref id, ref error, .. } => { - f.debug_struct("PoolEvent::PendingConnectionError") - .field("id", id) - .field("error", error) - .finish() - }, - PoolEvent::ConnectionEvent { ref connection, ref event } => { - f.debug_struct("PoolEvent::ConnectionEvent") - .field("peer", &connection.peer_id()) - .field("event", event) - .finish() - }, - PoolEvent::AddressChange { ref connection, ref new_endpoint, ref old_endpoint } => { - f.debug_struct("PoolEvent::AddressChange") - .field("peer", &connection.peer_id()) - .field("new_endpoint", new_endpoint) - .field("old_endpoint", old_endpoint) - .finish() - }, + PoolEvent::ConnectionEstablished { ref connection, .. } => f + .debug_tuple("PoolEvent::ConnectionEstablished") + .field(connection) + .finish(), + PoolEvent::ConnectionClosed { + ref id, + ref connected, + ref error, + .. + } => f + .debug_struct("PoolEvent::ConnectionClosed") + .field("id", id) + .field("connected", connected) + .field("error", error) + .finish(), + PoolEvent::PendingConnectionError { + ref id, ref error, .. + } => f + .debug_struct("PoolEvent::PendingConnectionError") + .field("id", id) + .field("error", error) + .finish(), + PoolEvent::ConnectionEvent { + ref connection, + ref event, + } => f + .debug_struct("PoolEvent::ConnectionEvent") + .field("peer", &connection.peer_id()) + .field("event", event) + .finish(), + PoolEvent::AddressChange { + ref connection, + ref new_endpoint, + ref old_endpoint, + } => f + .debug_struct("PoolEvent::AddressChange") + .field("peer", &connection.peer_id()) + .field("new_endpoint", new_endpoint) + .field("old_endpoint", old_endpoint) + .finish(), } } } -impl - Pool -{ +impl Pool { /// Creates a new empty `Pool`. - pub fn new( - local_id: PeerId, - manager_config: ManagerConfig, - limits: ConnectionLimits - ) -> Self { + pub fn new(local_id: PeerId, manager_config: ManagerConfig, limits: ConnectionLimits) -> Self { Pool { local_id, counters: ConnectionCounters::new(limits), @@ -234,13 +221,11 @@ impl info: IncomingInfo<'_>, ) -> Result where - TFut: Future< - Output = Result<(PeerId, TMuxer), PendingConnectionError> - > + Send + 'static, + TFut: Future>> + + Send + + 'static, THandler: IntoConnectionHandler + Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + THandler::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TTransErr: error::Error + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, @@ -263,13 +248,11 @@ impl info: OutgoingInfo<'_>, ) -> Result where - TFut: Future< - Output = Result<(PeerId, TMuxer), PendingConnectionError> - > + Send + 'static, + TFut: Future>> + + Send + + 'static, THandler: IntoConnectionHandler + Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + THandler::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TTransErr: error::Error + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, @@ -290,13 +273,11 @@ impl peer: Option, ) -> ConnectionId where - TFut: Future< - Output = Result<(PeerId, TMuxer), PendingConnectionError> - > + Send + 'static, + TFut: Future>> + + Send + + 'static, THandler: IntoConnectionHandler + Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + THandler::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TTransErr: error::Error + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, @@ -313,12 +294,12 @@ impl move |(peer_id, muxer)| { if let Some(peer) = expected_peer { if peer != peer_id { - return future::err(PendingConnectionError::InvalidPeerId) + return future::err(PendingConnectionError::InvalidPeerId); } } if local_id == peer_id { - return future::err(PendingConnectionError::InvalidPeerId) + return future::err(PendingConnectionError::InvalidPeerId); } let connected = Connected { peer_id, endpoint }; @@ -337,73 +318,80 @@ impl /// Returns the assigned connection ID on success. An error is returned /// if the configured maximum number of established connections for the /// connected peer has been reached. - pub fn add(&mut self, c: Connection, i: Connected) - -> Result + pub fn add( + &mut self, + c: Connection, + i: Connected, + ) -> Result where THandler: IntoConnectionHandler + Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = connection::Substream, - > + Send + 'static, + THandler::Handler: + ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TTransErr: error::Error + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, TMuxer::OutboundSubstream: Send + 'static, { self.counters.check_max_established(&i.endpoint)?; - self.counters.check_max_established_per_peer(self.num_peer_established(&i.peer_id))?; + self.counters + .check_max_established_per_peer(self.num_peer_established(&i.peer_id))?; let id = self.manager.add(c, i.clone()); self.counters.inc_established(&i.endpoint); - self.established.entry(i.peer_id).or_default().insert(id, i.endpoint); + self.established + .entry(i.peer_id) + .or_default() + .insert(id, i.endpoint); Ok(id) } /// Gets an entry representing a connection in the pool. /// /// Returns `None` if the pool has no connection with the given ID. - pub fn get(&mut self, id: ConnectionId) - -> Option>> - { + pub fn get( + &mut self, + id: ConnectionId, + ) -> Option>> { match self.manager.entry(id) { - Some(manager::Entry::Established(entry)) => - Some(PoolConnection::Established(EstablishedConnection { - entry - })), - Some(manager::Entry::Pending(entry)) => + Some(manager::Entry::Established(entry)) => { + Some(PoolConnection::Established(EstablishedConnection { entry })) + } + Some(manager::Entry::Pending(entry)) => { Some(PoolConnection::Pending(PendingConnection { entry, pending: &mut self.pending, counters: &mut self.counters, - })), - None => None + })) + } + None => None, } } /// Gets an established connection from the pool by ID. - pub fn get_established(&mut self, id: ConnectionId) - -> Option>> - { + pub fn get_established( + &mut self, + id: ConnectionId, + ) -> Option>> { match self.get(id) { Some(PoolConnection::Established(c)) => Some(c), - _ => None + _ => None, } } /// Gets a pending outgoing connection by ID. - pub fn get_outgoing(&mut self, id: ConnectionId) - -> Option>> - { + pub fn get_outgoing( + &mut self, + id: ConnectionId, + ) -> Option>> { match self.pending.get(&id) { - Some((ConnectedPoint::Dialer { .. }, _peer)) => - match self.manager.entry(id) { - Some(manager::Entry::Pending(entry)) => - Some(PendingConnection { - entry, - pending: &mut self.pending, - counters: &mut self.counters, - }), - _ => unreachable!("by consistency of `self.pending` with `self.manager`") - } - _ => None + Some((ConnectedPoint::Dialer { .. }, _peer)) => match self.manager.entry(id) { + Some(manager::Entry::Pending(entry)) => Some(PendingConnection { + entry, + pending: &mut self.pending, + counters: &mut self.counters, + }), + _ => unreachable!("by consistency of `self.pending` with `self.manager`"), + }, + _ => None, } } @@ -437,7 +425,9 @@ impl if let Some(manager::Entry::Established(e)) = self.manager.entry(id) { let connected = e.remove(); self.disconnected.push(Disconnected { - id, connected, num_established + id, + connected, + num_established, }); num_established += 1; } @@ -468,14 +458,13 @@ impl } /// Returns an iterator over all established connections of `peer`. - pub fn iter_peer_established<'a>(&'a mut self, peer: &PeerId) - -> EstablishedConnectionIter<'a, - impl Iterator, - THandler, - TTransErr, - > + pub fn iter_peer_established<'a>( + &'a mut self, + peer: &PeerId, + ) -> EstablishedConnectionIter<'a, impl Iterator, THandler, TTransErr> { - let ids = self.iter_peer_established_info(peer) + let ids = self + .iter_peer_established_info(peer) .map(|(id, _endpoint)| *id) .collect::>() .into_iter(); @@ -486,45 +475,50 @@ impl /// Returns an iterator for information on all pending incoming connections. pub fn iter_pending_incoming(&self) -> impl Iterator> { self.iter_pending_info() - .filter_map(|(_, ref endpoint, _)| { - match endpoint { - ConnectedPoint::Listener { local_addr, send_back_addr } => { - Some(IncomingInfo { local_addr, send_back_addr }) - }, - ConnectedPoint::Dialer { .. } => None, - } + .filter_map(|(_, ref endpoint, _)| match endpoint { + ConnectedPoint::Listener { + local_addr, + send_back_addr, + } => Some(IncomingInfo { + local_addr, + send_back_addr, + }), + ConnectedPoint::Dialer { .. } => None, }) } /// Returns an iterator for information on all pending outgoing connections. pub fn iter_pending_outgoing(&self) -> impl Iterator> { self.iter_pending_info() - .filter_map(|(_, ref endpoint, ref peer_id)| { - match endpoint { - ConnectedPoint::Listener { .. } => None, - ConnectedPoint::Dialer { address } => - Some(OutgoingInfo { address, peer_id: peer_id.as_ref() }), - } + .filter_map(|(_, ref endpoint, ref peer_id)| match endpoint { + ConnectedPoint::Listener { .. } => None, + ConnectedPoint::Dialer { address } => Some(OutgoingInfo { + address, + peer_id: peer_id.as_ref(), + }), }) } /// Returns an iterator over all connection IDs and associated endpoints /// of established connections to `peer` known to the pool. - pub fn iter_peer_established_info(&self, peer: &PeerId) - -> impl Iterator + fmt::Debug + '_ - { + pub fn iter_peer_established_info( + &self, + peer: &PeerId, + ) -> impl Iterator + fmt::Debug + '_ { match self.established.get(peer) { Some(conns) => Either::Left(conns.iter()), - None => Either::Right(std::iter::empty()) + None => Either::Right(std::iter::empty()), } } /// Returns an iterator over all pending connection IDs together /// with associated endpoints and expected peer IDs in the pool. - pub fn iter_pending_info(&self) - -> impl Iterator)> + '_ - { - self.pending.iter().map(|(id, (endpoint, info))| (id, endpoint, info)) + pub fn iter_pending_info( + &self, + ) -> impl Iterator)> + '_ { + self.pending + .iter() + .map(|(id, (endpoint, info))| (id, endpoint, info)) } /// Returns an iterator over all connected peers, i.e. those that have @@ -537,9 +531,10 @@ impl /// /// > **Note**: We use a regular `poll` method instead of implementing `Stream`, /// > because we want the `Pool` to stay borrowed if necessary. - pub fn poll<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll< - PoolEvent<'a, THandler, TTransErr> - > { + pub fn poll<'a>( + &'a mut self, + cx: &mut Context<'_>, + ) -> Poll> { // Drain events resulting from forced disconnections. // // Note: The `Disconnected` entries in `self.disconnected` @@ -548,15 +543,18 @@ impl // events in an order that properly counts down `num_established`. // See also `Pool::disconnect`. if let Some(Disconnected { - id, connected, num_established - }) = self.disconnected.pop() { + id, + connected, + num_established, + }) = self.disconnected.pop() + { return Poll::Ready(PoolEvent::ConnectionClosed { id, connected, num_established, error: None, pool: self, - }) + }); } // Poll the connection `Manager`. @@ -576,11 +574,15 @@ impl error, handler: Some(handler), peer, - pool: self - }) + pool: self, + }); } - }, - manager::Event::ConnectionClosed { id, connected, error } => { + } + manager::Event::ConnectionClosed { + id, + connected, + error, + } => { let num_established = if let Some(conns) = self.established.get_mut(&connected.peer_id) { if let Some(endpoint) = conns.remove(&id) { @@ -594,8 +596,12 @@ impl self.established.remove(&connected.peer_id); } return Poll::Ready(PoolEvent::ConnectionClosed { - id, connected, error, num_established, pool: self - }) + id, + connected, + error, + num_established, + pool: self, + }); } manager::Event::ConnectionEstablished { entry } => { let id = entry.id(); @@ -611,12 +617,13 @@ impl error: PendingConnectionError::ConnectionLimit(e), handler: None, peer, - pool: self - }) + pool: self, + }); } // Check per-peer established connection limit. - let current = num_peer_established(&self.established, &entry.connected().peer_id); + let current = + num_peer_established(&self.established, &entry.connected().peer_id); if let Err(e) = self.counters.check_max_established_per_peer(current) { let connected = entry.remove(); return Poll::Ready(PoolEvent::PendingConnectionError { @@ -625,8 +632,8 @@ impl error: PendingConnectionError::ConnectionLimit(e), handler: None, peer, - pool: self - }) + pool: self, + }); } // Peer ID checks must already have happened. See `add_pending`. @@ -644,54 +651,62 @@ impl // Add the connection to the pool. let peer = entry.connected().peer_id; let conns = self.established.entry(peer).or_default(); - let num_established = NonZeroU32::new(u32::try_from(conns.len() + 1).unwrap()) - .expect("n + 1 is always non-zero; qed"); + let num_established = + NonZeroU32::new(u32::try_from(conns.len() + 1).unwrap()) + .expect("n + 1 is always non-zero; qed"); self.counters.inc_established(&endpoint); conns.insert(id, endpoint); match self.get(id) { - Some(PoolConnection::Established(connection)) => + Some(PoolConnection::Established(connection)) => { return Poll::Ready(PoolEvent::ConnectionEstablished { - connection, num_established - }), - _ => unreachable!("since `entry` is an `EstablishedEntry`.") + connection, + num_established, + }) + } + _ => unreachable!("since `entry` is an `EstablishedEntry`."), } } - }, + } manager::Event::ConnectionEvent { entry, event } => { let id = entry.id(); match self.get(id) { - Some(PoolConnection::Established(connection)) => - return Poll::Ready(PoolEvent::ConnectionEvent { - connection, - event, - }), - _ => unreachable!("since `entry` is an `EstablishedEntry`.") + Some(PoolConnection::Established(connection)) => { + return Poll::Ready(PoolEvent::ConnectionEvent { connection, event }) + } + _ => unreachable!("since `entry` is an `EstablishedEntry`."), } - }, - manager::Event::AddressChange { entry, new_endpoint, old_endpoint } => { + } + manager::Event::AddressChange { + entry, + new_endpoint, + old_endpoint, + } => { let id = entry.id(); match self.established.get_mut(&entry.connected().peer_id) { - Some(list) => *list.get_mut(&id) - .expect("state inconsistency: entry is `EstablishedEntry` but absent \ - from `established`") = new_endpoint.clone(), - None => unreachable!("since `entry` is an `EstablishedEntry`.") + Some(list) => { + *list.get_mut(&id).expect( + "state inconsistency: entry is `EstablishedEntry` but absent \ + from `established`", + ) = new_endpoint.clone() + } + None => unreachable!("since `entry` is an `EstablishedEntry`."), }; match self.get(id) { - Some(PoolConnection::Established(connection)) => + Some(PoolConnection::Established(connection)) => { return Poll::Ready(PoolEvent::AddressChange { connection, new_endpoint, old_endpoint, - }), - _ => unreachable!("since `entry` is an `EstablishedEntry`.") + }) + } + _ => unreachable!("since `entry` is an `EstablishedEntry`."), } - }, + } } } } - } /// A connection in a [`Pool`]. @@ -707,9 +722,7 @@ pub struct PendingConnection<'a, TInEvent> { counters: &'a mut ConnectionCounters, } -impl - PendingConnection<'_, TInEvent> -{ +impl PendingConnection<'_, TInEvent> { /// Returns the local connection ID. pub fn id(&self) -> ConnectionId { self.entry.id() @@ -717,17 +730,29 @@ impl /// Returns the (expected) identity of the remote peer, if known. pub fn peer_id(&self) -> &Option { - &self.pending.get(&self.entry.id()).expect("`entry` is a pending entry").1 + &self + .pending + .get(&self.entry.id()) + .expect("`entry` is a pending entry") + .1 } /// Returns information about this endpoint of the connection. pub fn endpoint(&self) -> &ConnectedPoint { - &self.pending.get(&self.entry.id()).expect("`entry` is a pending entry").0 + &self + .pending + .get(&self.entry.id()) + .expect("`entry` is a pending entry") + .0 } /// Aborts the connection attempt, closing the connection. pub fn abort(self) { - let endpoint = self.pending.remove(&self.entry.id()).expect("`entry` is a pending entry").0; + let endpoint = self + .pending + .remove(&self.entry.id()) + .expect("`entry` is a pending entry") + .0; self.counters.dec_pending(&endpoint); self.entry.abort(); } @@ -738,8 +763,7 @@ pub struct EstablishedConnection<'a, TInEvent> { entry: manager::EstablishedEntry<'a, TInEvent>, } -impl fmt::Debug -for EstablishedConnection<'_, TInEvent> +impl fmt::Debug for EstablishedConnection<'_, TInEvent> where TInEvent: fmt::Debug, { @@ -790,7 +814,7 @@ impl EstablishedConnection<'_, TInEvent> { /// /// Returns `Err(())` if the background task associated with the connection /// is terminating and the connection is about to close. - pub fn poll_ready_notify_handler(&mut self, cx: &mut Context<'_>) -> Poll> { + pub fn poll_ready_notify_handler(&mut self, cx: &mut Context<'_>) -> Poll> { self.entry.poll_ready_notify_handler(cx) } @@ -811,21 +835,22 @@ pub struct EstablishedConnectionIter<'a, I, THandler: IntoConnectionHandler, TTr // Note: Ideally this would be an implementation of `Iterator`, but that // requires GATs (cf. https://github.com/rust-lang/rust/issues/44265) and // a different definition of `Iterator`. -impl<'a, I, THandler: IntoConnectionHandler, TTransErr> EstablishedConnectionIter<'a, I, THandler, TTransErr> +impl<'a, I, THandler: IntoConnectionHandler, TTransErr> + EstablishedConnectionIter<'a, I, THandler, TTransErr> where - I: Iterator + I: Iterator, { /// Obtains the next connection, if any. #[allow(clippy::should_implement_trait)] - pub fn next(&mut self) -> Option>> - { + pub fn next(&mut self) -> Option>> { while let Some(id) = self.ids.next() { - if self.pool.manager.is_established(&id) { // (*) + if self.pool.manager.is_established(&id) { + // (*) match self.pool.manager.entry(id) { Some(manager::Entry::Established(entry)) => { return Some(EstablishedConnection { entry }) } - _ => panic!("Established entry not found in manager.") // see (*) + _ => panic!("Established entry not found in manager."), // see (*) } } } @@ -838,17 +863,18 @@ where } /// Returns the first connection, if any, consuming the iterator. - pub fn into_first<'b>(mut self) - -> Option>> - where 'a: 'b + pub fn into_first<'b>(mut self) -> Option>> + where + 'a: 'b, { while let Some(id) = self.ids.next() { - if self.pool.manager.is_established(&id) { // (*) + if self.pool.manager.is_established(&id) { + // (*) match self.pool.manager.entry(id) { Some(manager::Entry::Established(entry)) => { return Some(EstablishedConnection { entry }) } - _ => panic!("Established entry not found in manager.") // see (*) + _ => panic!("Established entry not found in manager."), // see (*) } } } @@ -924,29 +950,45 @@ impl ConnectionCounters { fn inc_pending(&mut self, endpoint: &ConnectedPoint) { match endpoint { - ConnectedPoint::Dialer { .. } => { self.pending_outgoing += 1; } - ConnectedPoint::Listener { .. } => { self.pending_incoming += 1; } + ConnectedPoint::Dialer { .. } => { + self.pending_outgoing += 1; + } + ConnectedPoint::Listener { .. } => { + self.pending_incoming += 1; + } } } fn dec_pending(&mut self, endpoint: &ConnectedPoint) { match endpoint { - ConnectedPoint::Dialer { .. } => { self.pending_outgoing -= 1; } - ConnectedPoint::Listener { .. } => { self.pending_incoming -= 1; } + ConnectedPoint::Dialer { .. } => { + self.pending_outgoing -= 1; + } + ConnectedPoint::Listener { .. } => { + self.pending_incoming -= 1; + } } } fn inc_established(&mut self, endpoint: &ConnectedPoint) { match endpoint { - ConnectedPoint::Dialer { .. } => { self.established_outgoing += 1; } - ConnectedPoint::Listener { .. } => { self.established_incoming += 1; } + ConnectedPoint::Dialer { .. } => { + self.established_outgoing += 1; + } + ConnectedPoint::Listener { .. } => { + self.established_incoming += 1; + } } } fn dec_established(&mut self, endpoint: &ConnectedPoint) { match endpoint { - ConnectedPoint::Dialer { .. } => { self.established_outgoing -= 1; } - ConnectedPoint::Listener { .. } => { self.established_incoming -= 1; } + ConnectedPoint::Dialer { .. } => { + self.established_outgoing -= 1; + } + ConnectedPoint::Listener { .. } => { + self.established_incoming -= 1; + } } } @@ -958,18 +1000,19 @@ impl ConnectionCounters { Self::check(self.pending_incoming, self.limits.max_pending_incoming) } - fn check_max_established(&self, endpoint: &ConnectedPoint) - -> Result<(), ConnectionLimit> - { + fn check_max_established(&self, endpoint: &ConnectedPoint) -> Result<(), ConnectionLimit> { // Check total connection limit. Self::check(self.num_established(), self.limits.max_established_total)?; // Check incoming/outgoing connection limits match endpoint { - ConnectedPoint::Dialer { .. } => - Self::check(self.established_outgoing, self.limits.max_established_outgoing), - ConnectedPoint::Listener { .. } => { - Self::check(self.established_incoming, self.limits.max_established_incoming) - } + ConnectedPoint::Dialer { .. } => Self::check( + self.established_outgoing, + self.limits.max_established_outgoing, + ), + ConnectedPoint::Listener { .. } => Self::check( + self.established_incoming, + self.limits.max_established_incoming, + ), } } @@ -980,22 +1023,21 @@ impl ConnectionCounters { fn check(current: u32, limit: Option) -> Result<(), ConnectionLimit> { if let Some(limit) = limit { if current >= limit { - return Err(ConnectionLimit { limit, current }) + return Err(ConnectionLimit { limit, current }); } } Ok(()) } - } /// Counts the number of established connections to the given peer. fn num_peer_established( established: &FnvHashMap>, - peer: &PeerId + peer: &PeerId, ) -> u32 { - established.get(peer).map_or(0, |conns| - u32::try_from(conns.len()) - .expect("Unexpectedly large number of connections for a peer.")) + established.get(peer).map_or(0, |conns| { + u32::try_from(conns.len()).expect("Unexpectedly large number of connections for a peer.") + }) } /// The configurable connection limits. diff --git a/core/src/connection/substream.rs b/core/src/connection/substream.rs index ac537b488e9..399b09b9f0a 100644 --- a/core/src/connection/substream.rs +++ b/core/src/connection/substream.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::muxing::{StreamMuxer, StreamMuxerEvent, SubstreamRef, substream_from_ref}; +use crate::muxing::{substream_from_ref, StreamMuxer, StreamMuxerEvent, SubstreamRef}; use futures::prelude::*; use multiaddr::Multiaddr; use smallvec::SmallVec; @@ -135,7 +135,9 @@ where #[must_use] pub fn close(mut self) -> (Close, Vec) { let substreams = self.cancel_outgoing(); - let close = Close { muxer: self.inner.clone() }; + let close = Close { + muxer: self.inner.clone(), + }; (close, substreams) } @@ -150,17 +152,19 @@ where } /// Provides an API similar to `Future`. - pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll, IoError>> { + pub fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, IoError>> { // Polling inbound substream. match self.inner.poll_event(cx) { Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(substream))) => { let substream = substream_from_ref(self.inner.clone(), substream); - return Poll::Ready(Ok(SubstreamEvent::InboundSubstream { - substream, - })); + return Poll::Ready(Ok(SubstreamEvent::InboundSubstream { substream })); + } + Poll::Ready(Ok(StreamMuxerEvent::AddressChange(addr))) => { + return Poll::Ready(Ok(SubstreamEvent::AddressChange(addr))) } - Poll::Ready(Ok(StreamMuxerEvent::AddressChange(addr))) => - return Poll::Ready(Ok(SubstreamEvent::AddressChange(addr))), Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), Poll::Pending => {} } @@ -238,8 +242,7 @@ where TMuxer: StreamMuxer, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - f.debug_struct("Close") - .finish() + f.debug_struct("Close").finish() } } @@ -251,22 +254,22 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - SubstreamEvent::InboundSubstream { substream } => { - f.debug_struct("SubstreamEvent::OutboundClosed") - .field("substream", substream) - .finish() - }, - SubstreamEvent::OutboundSubstream { user_data, substream } => { - f.debug_struct("SubstreamEvent::OutboundSubstream") - .field("user_data", user_data) - .field("substream", substream) - .finish() - }, - SubstreamEvent::AddressChange(address) => { - f.debug_struct("SubstreamEvent::AddressChange") - .field("address", address) - .finish() - }, + SubstreamEvent::InboundSubstream { substream } => f + .debug_struct("SubstreamEvent::OutboundClosed") + .field("substream", substream) + .finish(), + SubstreamEvent::OutboundSubstream { + user_data, + substream, + } => f + .debug_struct("SubstreamEvent::OutboundSubstream") + .field("user_data", user_data) + .field("substream", substream) + .finish(), + SubstreamEvent::AddressChange(address) => f + .debug_struct("SubstreamEvent::AddressChange") + .field("address", address) + .finish(), } } } diff --git a/core/src/either.rs b/core/src/either.rs index 4d991936121..66a11589f7a 100644 --- a/core/src/either.rs +++ b/core/src/either.rs @@ -20,29 +20,31 @@ use crate::{ muxing::{StreamMuxer, StreamMuxerEvent}, - ProtocolName, - transport::{Transport, ListenerEvent, TransportError}, - Multiaddr + transport::{ListenerEvent, Transport, TransportError}, + Multiaddr, ProtocolName, +}; +use futures::{ + io::{IoSlice, IoSliceMut}, + prelude::*, }; -use futures::{prelude::*, io::{IoSlice, IoSliceMut}}; use pin_project::pin_project; -use std::{fmt, io::{Error as IoError}, pin::Pin, task::Context, task::Poll}; +use std::{fmt, io::Error as IoError, pin::Pin, task::Context, task::Poll}; #[derive(Debug, Copy, Clone)] pub enum EitherError { A(A), - B(B) + B(B), } impl fmt::Display for EitherError where A: fmt::Display, - B: fmt::Display + B: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { EitherError::A(a) => a.fmt(f), - EitherError::B(b) => b.fmt(f) + EitherError::B(b) => b.fmt(f), } } } @@ -50,12 +52,12 @@ where impl std::error::Error for EitherError where A: std::error::Error, - B: std::error::Error + B: std::error::Error, { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { EitherError::A(a) => a.source(), - EitherError::B(b) => b.source() + EitherError::B(b) => b.source(), } } } @@ -74,16 +76,22 @@ where A: AsyncRead, B: AsyncRead, { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { match self.project() { EitherOutputProj::First(a) => AsyncRead::poll_read(a, cx, buf), EitherOutputProj::Second(b) => AsyncRead::poll_read(b, cx, buf), } } - fn poll_read_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>]) - -> Poll> - { + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { match self.project() { EitherOutputProj::First(a) => AsyncRead::poll_read_vectored(a, cx, bufs), EitherOutputProj::Second(b) => AsyncRead::poll_read_vectored(b, cx, bufs), @@ -96,16 +104,22 @@ where A: AsyncWrite, B: AsyncWrite, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { match self.project() { EitherOutputProj::First(a) => AsyncWrite::poll_write(a, cx, buf), EitherOutputProj::Second(b) => AsyncWrite::poll_write(b, cx, buf), } } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) - -> Poll> - { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { match self.project() { EitherOutputProj::First(a) => AsyncWrite::poll_write_vectored(a, cx, bufs), EitherOutputProj::Second(b) => AsyncWrite::poll_write_vectored(b, cx, bufs), @@ -136,10 +150,12 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { - EitherOutputProj::First(a) => TryStream::try_poll_next(a, cx) - .map(|v| v.map(|r| r.map_err(EitherError::A))), - EitherOutputProj::Second(b) => TryStream::try_poll_next(b, cx) - .map(|v| v.map(|r| r.map_err(EitherError::B))), + EitherOutputProj::First(a) => { + TryStream::try_poll_next(a, cx).map(|v| v.map(|r| r.map_err(EitherError::A))) + } + EitherOutputProj::Second(b) => { + TryStream::try_poll_next(b, cx).map(|v| v.map(|r| r.map_err(EitherError::B))) + } } } } @@ -189,23 +205,24 @@ where type OutboundSubstream = EitherOutbound; type Error = IoError; - fn poll_event(&self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { match self { EitherOutput::First(inner) => inner.poll_event(cx).map(|result| { - result.map_err(|e| e.into()).map(|event| { - match event { - StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), - StreamMuxerEvent::InboundSubstream(substream) => - StreamMuxerEvent::InboundSubstream(EitherOutput::First(substream)) + result.map_err(|e| e.into()).map(|event| match event { + StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), + StreamMuxerEvent::InboundSubstream(substream) => { + StreamMuxerEvent::InboundSubstream(EitherOutput::First(substream)) } }) }), EitherOutput::Second(inner) => inner.poll_event(cx).map(|result| { - result.map_err(|e| e.into()).map(|event| { - match event { - StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), - StreamMuxerEvent::InboundSubstream(substream) => - StreamMuxerEvent::InboundSubstream(EitherOutput::Second(substream)) + result.map_err(|e| e.into()).map(|event| match event { + StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), + StreamMuxerEvent::InboundSubstream(substream) => { + StreamMuxerEvent::InboundSubstream(EitherOutput::Second(substream)) } }) }), @@ -219,96 +236,112 @@ where } } - fn poll_outbound(&self, cx: &mut Context<'_>, substream: &mut Self::OutboundSubstream) -> Poll> { + fn poll_outbound( + &self, + cx: &mut Context<'_>, + substream: &mut Self::OutboundSubstream, + ) -> Poll> { match (self, substream) { - (EitherOutput::First(ref inner), EitherOutbound::A(ref mut substream)) => { - inner.poll_outbound(cx, substream).map(|p| p.map(EitherOutput::First)).map_err(|e| e.into()) - }, - (EitherOutput::Second(ref inner), EitherOutbound::B(ref mut substream)) => { - inner.poll_outbound(cx, substream).map(|p| p.map(EitherOutput::Second)).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + (EitherOutput::First(ref inner), EitherOutbound::A(ref mut substream)) => inner + .poll_outbound(cx, substream) + .map(|p| p.map(EitherOutput::First)) + .map_err(|e| e.into()), + (EitherOutput::Second(ref inner), EitherOutbound::B(ref mut substream)) => inner + .poll_outbound(cx, substream) + .map(|p| p.map(EitherOutput::Second)) + .map_err(|e| e.into()), + _ => panic!("Wrong API usage"), } } fn destroy_outbound(&self, substream: Self::OutboundSubstream) { match self { - EitherOutput::First(inner) => { - match substream { - EitherOutbound::A(substream) => inner.destroy_outbound(substream), - _ => panic!("Wrong API usage") - } + EitherOutput::First(inner) => match substream { + EitherOutbound::A(substream) => inner.destroy_outbound(substream), + _ => panic!("Wrong API usage"), }, - EitherOutput::Second(inner) => { - match substream { - EitherOutbound::B(substream) => inner.destroy_outbound(substream), - _ => panic!("Wrong API usage") - } + EitherOutput::Second(inner) => match substream { + EitherOutbound::B(substream) => inner.destroy_outbound(substream), + _ => panic!("Wrong API usage"), }, } } - fn read_substream(&self, cx: &mut Context<'_>, sub: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + fn read_substream( + &self, + cx: &mut Context<'_>, + sub: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { inner.read_substream(cx, sub, buf).map_err(|e| e.into()) - }, + } (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { inner.read_substream(cx, sub, buf).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + } + _ => panic!("Wrong API usage"), } } - fn write_substream(&self, cx: &mut Context<'_>, sub: &mut Self::Substream, buf: &[u8]) -> Poll> { + fn write_substream( + &self, + cx: &mut Context<'_>, + sub: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { inner.write_substream(cx, sub, buf).map_err(|e| e.into()) - }, + } (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { inner.write_substream(cx, sub, buf).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + } + _ => panic!("Wrong API usage"), } } - fn flush_substream(&self, cx: &mut Context<'_>, sub: &mut Self::Substream) -> Poll> { + fn flush_substream( + &self, + cx: &mut Context<'_>, + sub: &mut Self::Substream, + ) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { inner.flush_substream(cx, sub).map_err(|e| e.into()) - }, + } (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { inner.flush_substream(cx, sub).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + } + _ => panic!("Wrong API usage"), } } - fn shutdown_substream(&self, cx: &mut Context<'_>, sub: &mut Self::Substream) -> Poll> { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + sub: &mut Self::Substream, + ) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { inner.shutdown_substream(cx, sub).map_err(|e| e.into()) - }, + } (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { inner.shutdown_substream(cx, sub).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + } + _ => panic!("Wrong API usage"), } } fn destroy_substream(&self, substream: Self::Substream) { match self { - EitherOutput::First(inner) => { - match substream { - EitherOutput::First(substream) => inner.destroy_substream(substream), - _ => panic!("Wrong API usage") - } + EitherOutput::First(inner) => match substream { + EitherOutput::First(substream) => inner.destroy_substream(substream), + _ => panic!("Wrong API usage"), }, - EitherOutput::Second(inner) => { - match substream { - EitherOutput::Second(substream) => inner.destroy_substream(substream), - _ => panic!("Wrong API usage") - } + EitherOutput::Second(inner) => match substream { + EitherOutput::Second(substream) => inner.destroy_substream(substream), + _ => panic!("Wrong API usage"), }, } } @@ -344,25 +377,33 @@ pub enum EitherListenStream { Second(#[pin] B), } -impl Stream for EitherListenStream +impl Stream + for EitherListenStream where AStream: TryStream, Error = AError>, BStream: TryStream, Error = BError>, { - type Item = Result, EitherError>, EitherError>; + type Item = Result< + ListenerEvent, EitherError>, + EitherError, + >; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { EitherListenStreamProj::First(a) => match TryStream::try_poll_next(a, cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le.map(EitherFuture::First).map_err(EitherError::A)))), + Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le + .map(EitherFuture::First) + .map_err(EitherError::A)))), Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), }, EitherListenStreamProj::Second(a) => match TryStream::try_poll_next(a, cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le.map(EitherFuture::Second).map_err(EitherError::B)))), + Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le + .map(EitherFuture::Second) + .map_err(EitherError::B)))), Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::B(err)))), }, } @@ -388,9 +429,11 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project() { EitherFutureProj::First(a) => TryFuture::try_poll(a, cx) - .map_ok(EitherOutput::First).map_err(EitherError::A), + .map_ok(EitherOutput::First) + .map_err(EitherError::A), EitherFutureProj::Second(a) => TryFuture::try_poll(a, cx) - .map_ok(EitherOutput::Second).map_err(EitherError::B), + .map_ok(EitherOutput::Second) + .map_err(EitherError::B), } } } @@ -398,7 +441,10 @@ where #[pin_project(project = EitherFuture2Proj)] #[derive(Debug, Copy, Clone)] #[must_use = "futures do nothing unless polled"] -pub enum EitherFuture2 { A(#[pin] A), B(#[pin] B) } +pub enum EitherFuture2 { + A(#[pin] A), + B(#[pin] B), +} impl Future for EitherFuture2 where @@ -410,21 +456,26 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project() { EitherFuture2Proj::A(a) => TryFuture::try_poll(a, cx) - .map_ok(EitherOutput::First).map_err(EitherError::A), + .map_ok(EitherOutput::First) + .map_err(EitherError::A), EitherFuture2Proj::B(a) => TryFuture::try_poll(a, cx) - .map_ok(EitherOutput::Second).map_err(EitherError::B), + .map_ok(EitherOutput::Second) + .map_err(EitherError::B), } } } #[derive(Debug, Clone)] -pub enum EitherName { A(A), B(B) } +pub enum EitherName { + A(A), + B(B), +} impl ProtocolName for EitherName { fn protocol_name(&self) -> &[u8] { match self { EitherName::A(a) => a.protocol_name(), - EitherName::B(b) => b.protocol_name() + EitherName::B(b) => b.protocol_name(), } } } diff --git a/core/src/identity.rs b/core/src/identity.rs index 6ed29424045..8c3c83db16e 100644 --- a/core/src/identity.rs +++ b/core/src/identity.rs @@ -41,7 +41,7 @@ pub mod secp256k1; pub mod error; use self::error::*; -use crate::{PeerId, keys_proto}; +use crate::{keys_proto, PeerId}; /// Identity keypair of a node. /// @@ -69,7 +69,7 @@ pub enum Keypair { Rsa(rsa::Keypair), /// A Secp256k1 keypair. #[cfg(feature = "secp256k1")] - Secp256k1(secp256k1::Keypair) + Secp256k1(secp256k1::Keypair), } impl Keypair { @@ -112,7 +112,7 @@ impl Keypair { #[cfg(not(target_arch = "wasm32"))] Rsa(ref pair) => pair.sign(msg), #[cfg(feature = "secp256k1")] - Secp256k1(ref pair) => pair.secret().sign(msg) + Secp256k1(ref pair) => pair.secret().sign(msg), } } @@ -154,7 +154,6 @@ impl Keypair { Ok(pk.encode_to_vec()) } - /// Decode a private key from a protobuf structure and parse it as a [`Keypair`]. pub fn from_protobuf_encoding(bytes: &[u8]) -> Result { use prost::Message; @@ -163,19 +162,20 @@ impl Keypair { .map_err(|e| DecodingError::new("Protobuf").source(e)) .map(zeroize::Zeroizing::new)?; - let key_type = keys_proto::KeyType::from_i32(private_key.r#type) - .ok_or_else(|| DecodingError::new(format!("unknown key type: {}", private_key.r#type)))?; + let key_type = keys_proto::KeyType::from_i32(private_key.r#type).ok_or_else(|| { + DecodingError::new(format!("unknown key type: {}", private_key.r#type)) + })?; match key_type { keys_proto::KeyType::Ed25519 => { ed25519::Keypair::decode(&mut private_key.data).map(Keypair::Ed25519) - }, - keys_proto::KeyType::Rsa => { - Err(DecodingError::new("Decoding RSA key from Protobuf is unsupported.")) - }, - keys_proto::KeyType::Secp256k1 => { - Err(DecodingError::new("Decoding Secp256k1 key from Protobuf is unsupported.")) - }, + } + keys_proto::KeyType::Rsa => Err(DecodingError::new( + "Decoding RSA key from Protobuf is unsupported.", + )), + keys_proto::KeyType::Secp256k1 => Err(DecodingError::new( + "Decoding Secp256k1 key from Protobuf is unsupported.", + )), } } } @@ -197,7 +197,7 @@ pub enum PublicKey { Rsa(rsa::PublicKey), #[cfg(feature = "secp256k1")] /// A public Secp256k1 key. - Secp256k1(secp256k1::PublicKey) + Secp256k1(secp256k1::PublicKey), } impl PublicKey { @@ -212,7 +212,7 @@ impl PublicKey { #[cfg(not(target_arch = "wasm32"))] Rsa(pk) => pk.verify(msg, sig), #[cfg(feature = "secp256k1")] - Secp256k1(pk) => pk.verify(msg, sig) + Secp256k1(pk) => pk.verify(msg, sig), } } @@ -222,27 +222,26 @@ impl PublicKey { use prost::Message; let public_key = match self { - PublicKey::Ed25519(key) => - keys_proto::PublicKey { - r#type: keys_proto::KeyType::Ed25519 as i32, - data: key.encode().to_vec() - }, + PublicKey::Ed25519(key) => keys_proto::PublicKey { + r#type: keys_proto::KeyType::Ed25519 as i32, + data: key.encode().to_vec(), + }, #[cfg(not(target_arch = "wasm32"))] - PublicKey::Rsa(key) => - keys_proto::PublicKey { - r#type: keys_proto::KeyType::Rsa as i32, - data: key.encode_x509() - }, + PublicKey::Rsa(key) => keys_proto::PublicKey { + r#type: keys_proto::KeyType::Rsa as i32, + data: key.encode_x509(), + }, #[cfg(feature = "secp256k1")] - PublicKey::Secp256k1(key) => - keys_proto::PublicKey { - r#type: keys_proto::KeyType::Secp256k1 as i32, - data: key.encode().to_vec() - } + PublicKey::Secp256k1(key) => keys_proto::PublicKey { + r#type: keys_proto::KeyType::Secp256k1 as i32, + data: key.encode().to_vec(), + }, }; let mut buf = Vec::with_capacity(public_key.encoded_len()); - public_key.encode(&mut buf).expect("Vec provides capacity as needed"); + public_key + .encode(&mut buf) + .expect("Vec provides capacity as needed"); buf } @@ -261,7 +260,7 @@ impl PublicKey { match key_type { keys_proto::KeyType::Ed25519 => { ed25519::PublicKey::decode(&pubkey.data).map(PublicKey::Ed25519) - }, + } #[cfg(not(target_arch = "wasm32"))] keys_proto::KeyType::Rsa => { rsa::PublicKey::decode_x509(&pubkey.data).map(PublicKey::Rsa) @@ -270,7 +269,7 @@ impl PublicKey { keys_proto::KeyType::Rsa => { log::debug!("support for RSA was disabled at compile-time"); Err(DecodingError::new("Unsupported")) - }, + } #[cfg(feature = "secp256k1")] keys_proto::KeyType::Secp256k1 => { secp256k1::PublicKey::decode(&pubkey.data).map(PublicKey::Secp256k1) @@ -311,7 +310,8 @@ mod tests { fn keypair_from_protobuf_encoding() { // E.g. retrieved from an IPFS config file. let base_64_encoded = "CAESQL6vdKQuznQosTrW7FWI9At+XX7EBf0BnZLhb6w+N+XSQSdfInl6c7U4NuxXJlhKcRBlBw9d0tj2dfBIVf6mcPA="; - let expected_peer_id = PeerId::from_str("12D3KooWEChVMMMzV8acJ53mJHrw1pQ27UAGkCxWXLJutbeUMvVu").unwrap(); + let expected_peer_id = + PeerId::from_str("12D3KooWEChVMMMzV8acJ53mJHrw1pQ27UAGkCxWXLJutbeUMvVu").unwrap(); let encoded = base64::decode(base_64_encoded).unwrap(); diff --git a/core/src/identity/ed25519.rs b/core/src/identity/ed25519.rs index f606a82b19b..5782ac788cb 100644 --- a/core/src/identity/ed25519.rs +++ b/core/src/identity/ed25519.rs @@ -20,12 +20,12 @@ //! Ed25519 keys. +use super::error::DecodingError; +use core::fmt; use ed25519_dalek::{self as ed25519, Signer as _, Verifier as _}; use rand::RngCore; use std::convert::TryFrom; -use super::error::DecodingError; use zeroize::Zeroize; -use core::fmt; /// An Ed25519 keypair. pub struct Keypair(ed25519::Keypair); @@ -49,7 +49,10 @@ impl Keypair { /// Note that this binary format is the same as `ed25519_dalek`'s and `ed25519_zebra`'s. pub fn decode(kp: &mut [u8]) -> Result { ed25519::Keypair::from_bytes(kp) - .map(|k| { kp.zeroize(); Keypair(k) }) + .map(|k| { + kp.zeroize(); + Keypair(k) + }) .map_err(|e| DecodingError::new("Ed25519 keypair").source(e)) } @@ -72,7 +75,9 @@ impl Keypair { impl fmt::Debug for Keypair { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Keypair").field("public", &self.0.public).finish() + f.debug_struct("Keypair") + .field("public", &self.0.public) + .finish() } } @@ -80,7 +85,8 @@ impl Clone for Keypair { fn clone(&self) -> Keypair { let mut sk_bytes = self.0.secret.to_bytes(); let secret = SecretKey::from_bytes(&mut sk_bytes) - .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k").0; + .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") + .0; let public = ed25519::PublicKey::from_bytes(&self.0.public.to_bytes()) .expect("ed25519::PublicKey::from_bytes(to_bytes(k)) != k"); Keypair(ed25519::Keypair { secret, public }) @@ -99,7 +105,10 @@ impl From for Keypair { fn from(sk: SecretKey) -> Keypair { let secret: ed25519::ExpandedSecretKey = (&sk.0).into(); let public = ed25519::PublicKey::from(&secret); - Keypair(ed25519::Keypair { secret: sk.0, public }) + Keypair(ed25519::Keypair { + secret: sk.0, + public, + }) } } @@ -120,7 +129,9 @@ impl fmt::Debug for PublicKey { impl PublicKey { /// Verify the Ed25519 signature on a message using the public key. pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool { - ed25519::Signature::try_from(sig).and_then(|s| self.0.verify(msg, &s)).is_ok() + ed25519::Signature::try_from(sig) + .and_then(|s| self.0.verify(msg, &s)) + .is_ok() } /// Encode the public key into a byte array in compressed form, i.e. @@ -150,8 +161,7 @@ impl AsRef<[u8]> for SecretKey { impl Clone for SecretKey { fn clone(&self) -> SecretKey { let mut sk_bytes = self.0.to_bytes(); - Self::from_bytes(&mut sk_bytes) - .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") + Self::from_bytes(&mut sk_bytes).expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") } } @@ -166,8 +176,11 @@ impl SecretKey { pub fn generate() -> SecretKey { let mut bytes = [0u8; 32]; rand::thread_rng().fill_bytes(&mut bytes); - SecretKey(ed25519::SecretKey::from_bytes(&bytes) - .expect("this returns `Err` only if the length is wrong; the length is correct; qed")) + SecretKey( + ed25519::SecretKey::from_bytes(&bytes).expect( + "this returns `Err` only if the length is wrong; the length is correct; qed", + ), + ) } /// Create an Ed25519 secret key from a byte slice, zeroing the input on success. @@ -188,9 +201,7 @@ mod tests { use quickcheck::*; fn eq_keypairs(kp1: &Keypair, kp2: &Keypair) -> bool { - kp1.public() == kp2.public() - && - kp1.0.secret.as_bytes() == kp2.0.secret.as_bytes() + kp1.public() == kp2.public() && kp1.0.secret.as_bytes() == kp2.0.secret.as_bytes() } #[test] @@ -199,9 +210,7 @@ mod tests { let kp1 = Keypair::generate(); let mut kp1_enc = kp1.encode(); let kp2 = Keypair::decode(&mut kp1_enc).unwrap(); - eq_keypairs(&kp1, &kp2) - && - kp1_enc.iter().all(|b| *b == 0) + eq_keypairs(&kp1, &kp2) && kp1_enc.iter().all(|b| *b == 0) } QuickCheck::new().tests(10).quickcheck(prop as fn() -> _); } @@ -212,9 +221,7 @@ mod tests { let kp1 = Keypair::generate(); let mut sk = kp1.0.secret.to_bytes(); let kp2 = Keypair::from(SecretKey::from_bytes(&mut sk).unwrap()); - eq_keypairs(&kp1, &kp2) - && - sk == [0u8; 32] + eq_keypairs(&kp1, &kp2) && sk == [0u8; 32] } QuickCheck::new().tests(10).quickcheck(prop as fn() -> _); } diff --git a/core/src/identity/error.rs b/core/src/identity/error.rs index 8fd1b1b9be9..76f41278d5d 100644 --- a/core/src/identity/error.rs +++ b/core/src/identity/error.rs @@ -27,16 +27,22 @@ use std::fmt; #[derive(Debug)] pub struct DecodingError { msg: String, - source: Option> + source: Option>, } impl DecodingError { pub(crate) fn new(msg: S) -> Self { - Self { msg: msg.to_string(), source: None } + Self { + msg: msg.to_string(), + source: None, + } } pub(crate) fn source(self, source: impl Error + Send + Sync + 'static) -> Self { - Self { source: Some(Box::new(source)), .. self } + Self { + source: Some(Box::new(source)), + ..self + } } } @@ -56,17 +62,23 @@ impl Error for DecodingError { #[derive(Debug)] pub struct SigningError { msg: String, - source: Option> + source: Option>, } /// An error during encoding of key material. impl SigningError { pub(crate) fn new(msg: S) -> Self { - Self { msg: msg.to_string(), source: None } + Self { + msg: msg.to_string(), + source: None, + } } pub(crate) fn source(self, source: impl Error + Send + Sync + 'static) -> Self { - Self { source: Some(Box::new(source)), .. self } + Self { + source: Some(Box::new(source)), + ..self + } } } @@ -81,4 +93,3 @@ impl Error for SigningError { self.source.as_ref().map(|s| &**s as &dyn Error) } } - diff --git a/core/src/identity/rsa.rs b/core/src/identity/rsa.rs index ffbfb975ff0..752bb156764 100644 --- a/core/src/identity/rsa.rs +++ b/core/src/identity/rsa.rs @@ -20,12 +20,12 @@ //! RSA keys. -use asn1_der::typed::{DerEncodable, DerDecodable, DerTypeView, Sequence}; -use asn1_der::{DerObject, Asn1DerError, Asn1DerErrorVariant, Sink, VecBacking}; use super::error::*; +use asn1_der::typed::{DerDecodable, DerEncodable, DerTypeView, Sequence}; +use asn1_der::{Asn1DerError, Asn1DerErrorVariant, DerObject, Sink, VecBacking}; use ring::rand::SystemRandom; -use ring::signature::{self, RsaKeyPair, RSA_PKCS1_SHA256, RSA_PKCS1_2048_8192_SHA256}; use ring::signature::KeyPair; +use ring::signature::{self, RsaKeyPair, RSA_PKCS1_2048_8192_SHA256, RSA_PKCS1_SHA256}; use std::{fmt, sync::Arc}; use zeroize::Zeroize; @@ -56,7 +56,7 @@ impl Keypair { let rng = SystemRandom::new(); match self.0.sign(&RSA_PKCS1_SHA256, &rng, &data, &mut signature) { Ok(()) => Ok(signature), - Err(e) => Err(SigningError::new("RSA").source(e)) + Err(e) => Err(SigningError::new("RSA").source(e)), } } } @@ -89,12 +89,14 @@ impl PublicKey { let spki = Asn1SubjectPublicKeyInfo { algorithmIdentifier: Asn1RsaEncryption { algorithm: Asn1OidRsaEncryption, - parameters: () + parameters: (), }, - subjectPublicKey: Asn1SubjectPublicKey(self.clone()) + subjectPublicKey: Asn1SubjectPublicKey(self.clone()), }; let mut buf = Vec::new(); - let buf = spki.encode(&mut buf).map(|_| buf) + let buf = spki + .encode(&mut buf) + .map(|_| buf) .expect("RSA X.509 public key encoding failed."); buf } @@ -127,7 +129,7 @@ impl fmt::Debug for PublicKey { /// A raw ASN1 OID. #[derive(Copy, Clone)] struct Asn1RawOid<'a> { - object: DerObject<'a> + object: DerObject<'a>, } impl<'a> Asn1RawOid<'a> { @@ -179,7 +181,7 @@ impl Asn1OidRsaEncryption { /// /// [RFC-3279]: https://tools.ietf.org/html/rfc3279#section-2.3.1 /// [RFC-5280]: https://tools.ietf.org/html/rfc5280#section-4.1 - const OID: [u8;9] = [ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 ]; + const OID: [u8; 9] = [0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01]; } impl DerEncodable for Asn1OidRsaEncryption { @@ -194,7 +196,7 @@ impl DerDecodable<'_> for Asn1OidRsaEncryption { oid if oid == Self::OID => Ok(Self), _ => Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData( "DER object is not the 'rsaEncryption' identifier.", - ))) + ))), } } } @@ -202,7 +204,7 @@ impl DerDecodable<'_> for Asn1OidRsaEncryption { /// The ASN.1 AlgorithmIdentifier for "rsaEncryption". struct Asn1RsaEncryption { algorithm: Asn1OidRsaEncryption, - parameters: () + parameters: (), } impl DerEncodable for Asn1RsaEncryption { @@ -211,7 +213,9 @@ impl DerEncodable for Asn1RsaEncryption { let algorithm = self.algorithm.der_object(VecBacking(&mut algorithm_buf))?; let mut parameters_buf = Vec::new(); - let parameters = self.parameters.der_object(VecBacking(&mut parameters_buf))?; + let parameters = self + .parameters + .der_object(VecBacking(&mut parameters_buf))?; Sequence::write(&[algorithm, parameters], sink) } @@ -221,7 +225,7 @@ impl DerDecodable<'_> for Asn1RsaEncryption { fn load(object: DerObject<'_>) -> Result { let seq: Sequence = Sequence::load(object)?; - Ok(Self{ + Ok(Self { algorithm: seq.get_as(0)?, parameters: seq.get_as(1)?, }) @@ -248,9 +252,9 @@ impl DerEncodable for Asn1SubjectPublicKey { impl DerDecodable<'_> for Asn1SubjectPublicKey { fn load(object: DerObject<'_>) -> Result { if object.tag() != 3 { - return Err(Asn1DerError::new( - Asn1DerErrorVariant::InvalidData("DER object tag is not the bit string tag."), - )); + return Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData( + "DER object tag is not the bit string tag.", + ))); } let pk_der: Vec = object.value().into_iter().skip(1).cloned().collect(); @@ -264,13 +268,15 @@ impl DerDecodable<'_> for Asn1SubjectPublicKey { #[allow(non_snake_case)] struct Asn1SubjectPublicKeyInfo { algorithmIdentifier: Asn1RsaEncryption, - subjectPublicKey: Asn1SubjectPublicKey + subjectPublicKey: Asn1SubjectPublicKey, } impl DerEncodable for Asn1SubjectPublicKeyInfo { fn encode(&self, sink: &mut S) -> Result<(), Asn1DerError> { let mut identifier_buf = Vec::new(); - let identifier = self.algorithmIdentifier.der_object(VecBacking(&mut identifier_buf))?; + let identifier = self + .algorithmIdentifier + .der_object(VecBacking(&mut identifier_buf))?; let mut key_buf = Vec::new(); let key = self.subjectPublicKey.der_object(VecBacking(&mut key_buf))?; @@ -340,6 +346,8 @@ mod tests { fn prop(SomeKeypair(kp): SomeKeypair, msg: Vec) -> Result { kp.sign(&msg).map(|s| kp.public().verify(&msg, &s)) } - QuickCheck::new().tests(10).quickcheck(prop as fn(_,_) -> _); + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _) -> _); } } diff --git a/core/src/identity/secp256k1.rs b/core/src/identity/secp256k1.rs index be887064131..2c3aaf89a51 100644 --- a/core/src/identity/secp256k1.rs +++ b/core/src/identity/secp256k1.rs @@ -20,18 +20,18 @@ //! Secp256k1 keys. +use super::error::{DecodingError, SigningError}; use asn1_der::typed::{DerDecodable, Sequence}; -use sha2::{Digest as ShaDigestTrait, Sha256}; +use core::fmt; use libsecp256k1::{Message, Signature}; -use super::error::{DecodingError, SigningError}; +use sha2::{Digest as ShaDigestTrait, Sha256}; use zeroize::Zeroize; -use core::fmt; /// A Secp256k1 keypair. #[derive(Clone)] pub struct Keypair { secret: SecretKey, - public: PublicKey + public: PublicKey, } impl Keypair { @@ -53,7 +53,9 @@ impl Keypair { impl fmt::Debug for Keypair { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Keypair").field("public", &self.public).finish() + f.debug_struct("Keypair") + .field("public", &self.public) + .finish() } } @@ -110,10 +112,11 @@ impl SecretKey { let der_obj = der.as_mut(); let obj: Sequence = DerDecodable::decode(der_obj) .map_err(|e| DecodingError::new("Secp256k1 DER ECPrivateKey").source(e))?; - let sk_obj = obj.get(1) + let sk_obj = obj + .get(1) .map_err(|e| DecodingError::new("Not enough elements in DER").source(e))?; - let mut sk_bytes: Vec = asn1_der::typed::DerDecodable::load(sk_obj) - .map_err(DecodingError::new)?; + let mut sk_bytes: Vec = + asn1_der::typed::DerDecodable::load(sk_obj).map_err(DecodingError::new)?; let sk = SecretKey::from_bytes(&mut sk_bytes)?; sk_bytes.zeroize(); der_obj.zeroize(); @@ -138,7 +141,11 @@ impl SecretKey { pub fn sign_hash(&self, msg: &[u8]) -> Result, SigningError> { let m = Message::parse_slice(msg) .map_err(|_| SigningError::new("failed to parse secp256k1 digest"))?; - Ok(libsecp256k1::sign(&m, &self.0).0.serialize_der().as_ref().into()) + Ok(libsecp256k1::sign(&m, &self.0) + .0 + .serialize_der() + .as_ref() + .into()) } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 844fd2a23bc..60727c52062 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -54,16 +54,16 @@ pub mod network; pub mod transport; pub mod upgrade; +pub use connection::{Connected, ConnectedPoint, Endpoint}; +pub use identity::PublicKey; pub use multiaddr::Multiaddr; pub use multihash; pub use muxing::StreamMuxer; +pub use network::Network; pub use peer_id::PeerId; -pub use identity::PublicKey; -pub use transport::Transport; pub use translation::address_translation; -pub use upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, UpgradeError, ProtocolName}; -pub use connection::{Connected, Endpoint, ConnectedPoint}; -pub use network::Network; +pub use transport::Transport; +pub use upgrade::{InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeError, UpgradeInfo}; use std::{future::Future, pin::Pin}; diff --git a/core/src/muxing.rs b/core/src/muxing.rs index c8ae456fac2..12beb51d9dd 100644 --- a/core/src/muxing.rs +++ b/core/src/muxing.rs @@ -55,7 +55,12 @@ use fnv::FnvHashMap; use futures::{future, prelude::*, task::Context, task::Poll}; use multiaddr::Multiaddr; use parking_lot::Mutex; -use std::{io, ops::Deref, fmt, pin::Pin, sync::atomic::{AtomicUsize, Ordering}}; +use std::{ + fmt, io, + ops::Deref, + pin::Pin, + sync::atomic::{AtomicUsize, Ordering}, +}; pub use self::singleton::SingletonMuxer; @@ -95,7 +100,10 @@ pub trait StreamMuxer { /// work, such as processing incoming packets and polling timers. /// /// An error can be generated if the connection has been closed. - fn poll_event(&self, cx: &mut Context<'_>) -> Poll, Self::Error>>; + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>; /// Opens a new outgoing substream, and produces the equivalent to a future that will be /// resolved when it becomes available. @@ -113,8 +121,11 @@ pub trait StreamMuxer { /// /// May panic or produce an undefined result if an earlier polling of the same substream /// returned `Ready` or `Err`. - fn poll_outbound(&self, cx: &mut Context<'_>, s: &mut Self::OutboundSubstream) - -> Poll>; + fn poll_outbound( + &self, + cx: &mut Context<'_>, + s: &mut Self::OutboundSubstream, + ) -> Poll>; /// Destroys an outbound substream future. Use this after the outbound substream has finished, /// or if you want to interrupt it. @@ -131,8 +142,12 @@ pub trait StreamMuxer { /// /// An error can be generated if the connection has been closed, or if a protocol misbehaviour /// happened. - fn read_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &mut [u8]) - -> Poll>; + fn read_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll>; /// Write data to a substream. The behaviour is the same as `futures::AsyncWrite::poll_write`. /// @@ -145,8 +160,12 @@ pub trait StreamMuxer { /// /// It is incorrect to call this method on a substream if you called `shutdown_substream` on /// this substream earlier. - fn write_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &[u8]) - -> Poll>; + fn write_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &[u8], + ) -> Poll>; /// Flushes a substream. The behaviour is the same as `futures::AsyncWrite::poll_flush`. /// @@ -158,8 +177,11 @@ pub trait StreamMuxer { /// call this method may be notified. /// /// > **Note**: This method may be implemented as a call to `flush_all`. - fn flush_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) - -> Poll>; + fn flush_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll>; /// Attempts to shut down the writing side of a substream. The behaviour is similar to /// `AsyncWrite::poll_close`. @@ -172,8 +194,11 @@ pub trait StreamMuxer { /// /// An error can be generated if the connection has been closed, or if a protocol misbehaviour /// happened. - fn shutdown_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) - -> Poll>; + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll>; /// Destroys a substream. fn destroy_substream(&self, s: Self::Substream); @@ -246,14 +271,12 @@ where P::Target: StreamMuxer, { let muxer2 = muxer.clone(); - future::poll_fn(move |cx| muxer.poll_event(cx)) - .map_ok(|event| { - match event { - StreamMuxerEvent::InboundSubstream(substream) => - StreamMuxerEvent::InboundSubstream(substream_from_ref(muxer2, substream)), - StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), - } - }) + future::poll_fn(move |cx| muxer.poll_event(cx)).map_ok(|event| match event { + StreamMuxerEvent::InboundSubstream(substream) => { + StreamMuxerEvent::InboundSubstream(substream_from_ref(muxer2, substream)) + } + StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), + }) } /// Same as `outbound_from_ref`, but wraps the output in an object that @@ -336,7 +359,8 @@ where // We use a `this` because the compiler isn't smart enough to allow mutably borrowing // multiple different fields from the `Pin` at the same time. let this = &mut *self; - this.muxer.poll_outbound(cx, this.outbound.as_mut().expect("outbound was empty")) + this.muxer + .poll_outbound(cx, this.outbound.as_mut().expect("outbound was empty")) } } @@ -408,7 +432,11 @@ where P: Deref, P::Target: StreamMuxer, { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { // We use a `this` because the compiler isn't smart enough to allow mutably borrowing // multiple different fields from the `Pin` at the same time. let this = &mut *self; @@ -423,7 +451,11 @@ where P: Deref, P::Target: StreamMuxer, { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { // We use a `this` because the compiler isn't smart enough to allow mutably borrowing // multiple different fields from the `Pin` at the same time. let this = &mut *self; @@ -440,20 +472,16 @@ where let s = this.substream.as_mut().expect("substream was empty"); loop { match this.shutdown_state { - ShutdownState::Shutdown => { - match this.muxer.shutdown_substream(cx, s) { - Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Flush, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), - Poll::Pending => return Poll::Pending, - } - } - ShutdownState::Flush => { - match this.muxer.flush_substream(cx, s) { - Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Done, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), - Poll::Pending => return Poll::Pending, - } - } + ShutdownState::Shutdown => match this.muxer.shutdown_substream(cx, s) { + Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Flush, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => return Poll::Pending, + }, + ShutdownState::Flush => match this.muxer.flush_substream(cx, s) { + Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Done, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => return Poll::Pending, + }, ShutdownState::Done => { return Poll::Ready(Ok(())); } @@ -477,13 +505,18 @@ where P::Target: StreamMuxer, { fn drop(&mut self) { - self.muxer.destroy_substream(self.substream.take().expect("substream was empty")) + self.muxer + .destroy_substream(self.substream.take().expect("substream was empty")) } } /// Abstract `StreamMuxer`. pub struct StreamMuxerBox { - inner: Box + Send + Sync>, + inner: Box< + dyn StreamMuxer + + Send + + Sync, + >, } impl StreamMuxerBox { @@ -514,7 +547,10 @@ impl StreamMuxer for StreamMuxerBox { type Error = io::Error; #[inline] - fn poll_event(&self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { self.inner.poll_event(cx) } @@ -524,7 +560,11 @@ impl StreamMuxer for StreamMuxerBox { } #[inline] - fn poll_outbound(&self, cx: &mut Context<'_>, s: &mut Self::OutboundSubstream) -> Poll> { + fn poll_outbound( + &self, + cx: &mut Context<'_>, + s: &mut Self::OutboundSubstream, + ) -> Poll> { self.inner.poll_outbound(cx, s) } @@ -534,22 +574,40 @@ impl StreamMuxer for StreamMuxerBox { } #[inline] - fn read_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + fn read_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> { self.inner.read_substream(cx, s, buf) } #[inline] - fn write_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &[u8]) -> Poll> { + fn write_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { self.inner.write_substream(cx, s, buf) } #[inline] - fn flush_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) -> Poll> { + fn flush_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { self.inner.flush_substream(cx, s) } #[inline] - fn shutdown_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) -> Poll> { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { self.inner.shutdown_substream(cx, s) } @@ -569,7 +627,10 @@ impl StreamMuxer for StreamMuxerBox { } } -struct Wrap where T: StreamMuxer { +struct Wrap +where + T: StreamMuxer, +{ inner: T, substreams: Mutex>, next_substream: AtomicUsize, @@ -586,11 +647,15 @@ where type Error = io::Error; #[inline] - fn poll_event(&self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { let substream = match self.inner.poll_event(cx) { Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(StreamMuxerEvent::AddressChange(a))) => - return Poll::Ready(Ok(StreamMuxerEvent::AddressChange(a))), + Poll::Ready(Ok(StreamMuxerEvent::AddressChange(a))) => { + return Poll::Ready(Ok(StreamMuxerEvent::AddressChange(a))) + } Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(s))) => s, Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), }; @@ -615,7 +680,10 @@ where substream: &mut Self::OutboundSubstream, ) -> Poll> { let mut list = self.outbound.lock(); - let substream = match self.inner.poll_outbound(cx, list.get_mut(substream).unwrap()) { + let substream = match self + .inner + .poll_outbound(cx, list.get_mut(substream).unwrap()) + { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(s)) => s, Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), @@ -628,37 +696,65 @@ where #[inline] fn destroy_outbound(&self, substream: Self::OutboundSubstream) { let mut list = self.outbound.lock(); - self.inner.destroy_outbound(list.remove(&substream).unwrap()) + self.inner + .destroy_outbound(list.remove(&substream).unwrap()) } #[inline] - fn read_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + fn read_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> { let mut list = self.substreams.lock(); - self.inner.read_substream(cx, list.get_mut(s).unwrap(), buf).map_err(|e| e.into()) + self.inner + .read_substream(cx, list.get_mut(s).unwrap(), buf) + .map_err(|e| e.into()) } #[inline] - fn write_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &[u8]) -> Poll> { + fn write_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { let mut list = self.substreams.lock(); - self.inner.write_substream(cx, list.get_mut(s).unwrap(), buf).map_err(|e| e.into()) + self.inner + .write_substream(cx, list.get_mut(s).unwrap(), buf) + .map_err(|e| e.into()) } #[inline] - fn flush_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) -> Poll> { + fn flush_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { let mut list = self.substreams.lock(); - self.inner.flush_substream(cx, list.get_mut(s).unwrap()).map_err(|e| e.into()) + self.inner + .flush_substream(cx, list.get_mut(s).unwrap()) + .map_err(|e| e.into()) } #[inline] - fn shutdown_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) -> Poll> { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { let mut list = self.substreams.lock(); - self.inner.shutdown_substream(cx, list.get_mut(s).unwrap()).map_err(|e| e.into()) + self.inner + .shutdown_substream(cx, list.get_mut(s).unwrap()) + .map_err(|e| e.into()) } #[inline] fn destroy_substream(&self, substream: Self::Substream) { let mut list = self.substreams.lock(); - self.inner.destroy_substream(list.remove(&substream).unwrap()) + self.inner + .destroy_substream(list.remove(&substream).unwrap()) } #[inline] diff --git a/core/src/muxing/singleton.rs b/core/src/muxing/singleton.rs index 47701f07139..749e9cd673e 100644 --- a/core/src/muxing/singleton.rs +++ b/core/src/muxing/singleton.rs @@ -18,11 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{connection::Endpoint, muxing::{StreamMuxer, StreamMuxerEvent}}; +use crate::{ + connection::Endpoint, + muxing::{StreamMuxer, StreamMuxerEvent}, +}; use futures::prelude::*; use parking_lot::Mutex; -use std::{io, pin::Pin, sync::atomic::{AtomicBool, Ordering}, task::Context, task::Poll}; +use std::{ + io, + pin::Pin, + sync::atomic::{AtomicBool, Ordering}, + task::Context, + task::Poll, +}; /// Implementation of `StreamMuxer` that allows only one substream on top of a connection, /// yielding the connection itself. @@ -65,7 +74,10 @@ where type OutboundSubstream = OutboundSubstream; type Error = io::Error; - fn poll_event(&self, _: &mut Context<'_>) -> Poll, io::Error>> { + fn poll_event( + &self, + _: &mut Context<'_>, + ) -> Poll, io::Error>> { match self.endpoint { Endpoint::Dialer => return Poll::Pending, Endpoint::Listener => {} @@ -82,7 +94,11 @@ where OutboundSubstream {} } - fn poll_outbound(&self, _: &mut Context<'_>, _: &mut Self::OutboundSubstream) -> Poll> { + fn poll_outbound( + &self, + _: &mut Context<'_>, + _: &mut Self::OutboundSubstream, + ) -> Poll> { match self.endpoint { Endpoint::Listener => return Poll::Pending, Endpoint::Dialer => {} @@ -95,27 +111,43 @@ where } } - fn destroy_outbound(&self, _: Self::OutboundSubstream) { - } + fn destroy_outbound(&self, _: Self::OutboundSubstream) {} - fn read_substream(&self, cx: &mut Context<'_>, _: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + fn read_substream( + &self, + cx: &mut Context<'_>, + _: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> { AsyncRead::poll_read(Pin::new(&mut *self.inner.lock()), cx, buf) } - fn write_substream(&self, cx: &mut Context<'_>, _: &mut Self::Substream, buf: &[u8]) -> Poll> { + fn write_substream( + &self, + cx: &mut Context<'_>, + _: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { AsyncWrite::poll_write(Pin::new(&mut *self.inner.lock()), cx, buf) } - fn flush_substream(&self, cx: &mut Context<'_>, _: &mut Self::Substream) -> Poll> { + fn flush_substream( + &self, + cx: &mut Context<'_>, + _: &mut Self::Substream, + ) -> Poll> { AsyncWrite::poll_flush(Pin::new(&mut *self.inner.lock()), cx) } - fn shutdown_substream(&self, cx: &mut Context<'_>, _: &mut Self::Substream) -> Poll> { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + _: &mut Self::Substream, + ) -> Poll> { AsyncWrite::poll_close(Pin::new(&mut *self.inner.lock()), cx) } - fn destroy_substream(&self, _: Self::Substream) { - } + fn destroy_substream(&self, _: Self::Substream) {} fn close(&self, cx: &mut Context<'_>) -> Poll> { // The `StreamMuxer` trait requires that `close()` implies `flush_all()`. diff --git a/core/src/network.rs b/core/src/network.rs index c069171c7f3..784c1e01ca7 100644 --- a/core/src/network.rs +++ b/core/src/network.rs @@ -21,45 +21,30 @@ mod event; pub mod peer; -pub use crate::connection::{ConnectionLimits, ConnectionCounters}; -pub use event::{NetworkEvent, IncomingConnection}; +pub use crate::connection::{ConnectionCounters, ConnectionLimits}; +pub use event::{IncomingConnection, NetworkEvent}; pub use peer::Peer; use crate::{ - ConnectedPoint, - Executor, - Multiaddr, - PeerId, connection::{ - ConnectionId, - ConnectionLimit, - ConnectionHandler, - IntoConnectionHandler, - IncomingInfo, - OutgoingInfo, - ListenersEvent, - ListenerId, - ListenersStream, - PendingConnectionError, - Substream, - handler::{ - THandlerInEvent, - THandlerOutEvent, - }, + handler::{THandlerInEvent, THandlerOutEvent}, manager::ManagerConfig, pool::{Pool, PoolEvent}, + ConnectionHandler, ConnectionId, ConnectionLimit, IncomingInfo, IntoConnectionHandler, + ListenerId, ListenersEvent, ListenersStream, OutgoingInfo, PendingConnectionError, + Substream, }, muxing::StreamMuxer, transport::{Transport, TransportError}, + ConnectedPoint, Executor, Multiaddr, PeerId, }; -use fnv::{FnvHashMap}; -use futures::{prelude::*, future}; +use fnv::FnvHashMap; +use futures::{future, prelude::*}; use smallvec::SmallVec; use std::{ collections::hash_map, convert::TryFrom as _, - error, - fmt, + error, fmt, num::NonZeroUsize, pin::Pin, task::{Context, Poll}, @@ -95,8 +80,7 @@ where dialing: FnvHashMap>, } -impl fmt::Debug for - Network +impl fmt::Debug for Network where TTrans: fmt::Debug + Transport, THandler: fmt::Debug + ConnectionHandler, @@ -111,16 +95,14 @@ where } } -impl Unpin for - Network +impl Unpin for Network where TTrans: Transport, THandler: IntoConnectionHandler, { } -impl - Network +impl Network where TTrans: Transport, THandler: IntoConnectionHandler, @@ -131,8 +113,7 @@ where } } -impl - Network +impl Network where TTrans: Transport + Clone, TMuxer: StreamMuxer, @@ -142,11 +123,7 @@ where THandler::Handler: ConnectionHandler> + Send, { /// Creates a new node events stream. - pub fn new( - transport: TTrans, - local_peer_id: PeerId, - config: NetworkConfig, - ) -> Self { + pub fn new(transport: TTrans, local_peer_id: PeerId, config: NetworkConfig) -> Self { Network { local_peer_id, listeners: ListenersStream::new(transport), @@ -161,7 +138,10 @@ where } /// Start listening on the given multiaddress. - pub fn listen_on(&mut self, addr: Multiaddr) -> Result> { + pub fn listen_on( + &mut self, + addr: Multiaddr, + ) -> Result> { self.listeners.listen_on(addr) } @@ -189,14 +169,14 @@ where /// other than the peer who reported the `observed_addr`. /// /// The translation is transport-specific. See [`Transport::address_translation`]. - pub fn address_translation<'a>(&'a self, observed_addr: &'a Multiaddr) - -> Vec + pub fn address_translation<'a>(&'a self, observed_addr: &'a Multiaddr) -> Vec where TMuxer: 'a, THandler: 'a, { let transport = self.listeners.transport(); - let mut addrs: Vec<_> = self.listen_addrs() + let mut addrs: Vec<_> = self + .listen_addrs() .filter_map(move |server| transport.address_translation(server, observed_addr)) .collect(); @@ -218,8 +198,11 @@ where /// The given `handler` will be used to create the /// [`Connection`](crate::connection::Connection) upon success and the /// connection ID is returned. - pub fn dial(&mut self, address: &Multiaddr, handler: THandler) - -> Result + pub fn dial( + &mut self, + address: &Multiaddr, + handler: THandler, + ) -> Result where TTrans: Transport, TTrans::Error: Send + 'static, @@ -238,21 +221,29 @@ where address: address.clone(), handler, remaining: Vec::new(), - }) + }); } } // The address does not specify an expected peer, so just try to dial it as-is, // accepting any peer ID that the remote identifies as. - let info = OutgoingInfo { address, peer_id: None }; + let info = OutgoingInfo { + address, + peer_id: None, + }; match self.transport().clone().dial(address.clone()) { Ok(f) => { - let f = f.map_err(|err| PendingConnectionError::Transport(TransportError::Other(err))); - self.pool.add_outgoing(f, handler, info).map_err(DialError::ConnectionLimit) + let f = + f.map_err(|err| PendingConnectionError::Transport(TransportError::Other(err))); + self.pool + .add_outgoing(f, handler, info) + .map_err(DialError::ConnectionLimit) } Err(err) => { let f = future::err(PendingConnectionError::Transport(err)); - self.pool.add_outgoing(f, handler, info).map_err(DialError::ConnectionLimit) + self.pool + .add_outgoing(f, handler, info) + .map_err(DialError::ConnectionLimit) } } } @@ -274,14 +265,13 @@ where /// Returns the list of addresses we're currently dialing without knowing the `PeerId` of. pub fn unknown_dials(&self) -> impl Iterator { - self.pool.iter_pending_outgoing() - .filter_map(|info| { - if info.peer_id.is_none() { - Some(info.address) - } else { - None - } - }) + self.pool.iter_pending_outgoing().filter_map(|info| { + if info.peer_id.is_none() { + Some(info.address) + } else { + None + } + }) } /// Returns a list of all connected peers, i.e. peers to whom the `Network` @@ -313,9 +303,7 @@ where } /// Obtains a view of a [`Peer`] with the given ID in the network. - pub fn peer(&mut self, peer_id: PeerId) - -> Peer<'_, TTrans, THandler> - { + pub fn peer(&mut self, peer_id: PeerId) -> Peer<'_, TTrans, THandler> { Peer::new(self, peer_id) } @@ -336,8 +324,9 @@ where TTrans::Error: Send + 'static, TTrans::ListenerUpgrade: Send + 'static, { - let upgrade = connection.upgrade.map_err(|err| - PendingConnectionError::Transport(TransportError::Other(err))); + let upgrade = connection + .upgrade + .map_err(|err| PendingConnectionError::Transport(TransportError::Other(err))); let info = IncomingInfo { local_addr: &connection.local_addr, send_back_addr: &connection.send_back_addr, @@ -346,7 +335,12 @@ where } /// Provides an API similar to `Stream`, except that it cannot error. - pub fn poll<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll, THandlerOutEvent, THandler>> + pub fn poll<'a>( + &'a mut self, + cx: &mut Context<'_>, + ) -> Poll< + NetworkEvent<'a, TTrans, THandlerInEvent, THandlerOutEvent, THandler>, + > where TTrans: Transport, TTrans::Error: Send + 'static, @@ -364,7 +358,7 @@ where listener_id, upgrade, local_addr, - send_back_addr + send_back_addr, }) => { return Poll::Ready(NetworkEvent::IncomingConnection { listener_id, @@ -372,17 +366,37 @@ where upgrade, local_addr, send_back_addr, - } + }, }) } - Poll::Ready(ListenersEvent::NewAddress { listener_id, listen_addr }) => { - return Poll::Ready(NetworkEvent::NewListenerAddress { listener_id, listen_addr }) + Poll::Ready(ListenersEvent::NewAddress { + listener_id, + listen_addr, + }) => { + return Poll::Ready(NetworkEvent::NewListenerAddress { + listener_id, + listen_addr, + }) } - Poll::Ready(ListenersEvent::AddressExpired { listener_id, listen_addr }) => { - return Poll::Ready(NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr }) + Poll::Ready(ListenersEvent::AddressExpired { + listener_id, + listen_addr, + }) => { + return Poll::Ready(NetworkEvent::ExpiredListenerAddress { + listener_id, + listen_addr, + }) } - Poll::Ready(ListenersEvent::Closed { listener_id, addresses, reason }) => { - return Poll::Ready(NetworkEvent::ListenerClosed { listener_id, addresses, reason }) + Poll::Ready(ListenersEvent::Closed { + listener_id, + addresses, + reason, + }) => { + return Poll::Ready(NetworkEvent::ListenerClosed { + listener_id, + addresses, + reason, + }) } Poll::Ready(ListenersEvent::Error { listener_id, error }) => { return Poll::Ready(NetworkEvent::ListenerError { listener_id, error }) @@ -392,7 +406,10 @@ where // Poll the known peers. let event = match self.pool.poll(cx) { Poll::Pending => return Poll::Pending, - Poll::Ready(PoolEvent::ConnectionEstablished { connection, num_established }) => { + Poll::Ready(PoolEvent::ConnectionEstablished { + connection, + num_established, + }) => { if let hash_map::Entry::Occupied(mut e) = self.dialing.entry(connection.peer_id()) { e.get_mut().retain(|s| s.current.0 != connection.id()); if e.get().is_empty() { @@ -405,7 +422,14 @@ where num_established, } } - Poll::Ready(PoolEvent::PendingConnectionError { id, endpoint, error, handler, pool, .. }) => { + Poll::Ready(PoolEvent::PendingConnectionError { + id, + endpoint, + error, + handler, + pool, + .. + }) => { let dialing = &mut self.dialing; let (next, event) = on_connection_failed(dialing, id, endpoint, error, handler); if let Some(dial) = next { @@ -416,35 +440,37 @@ where } event } - Poll::Ready(PoolEvent::ConnectionClosed { id, connected, error, num_established, .. }) => { - NetworkEvent::ConnectionClosed { - id, - connected, - num_established, - error, - } - } + Poll::Ready(PoolEvent::ConnectionClosed { + id, + connected, + error, + num_established, + .. + }) => NetworkEvent::ConnectionClosed { + id, + connected, + num_established, + error, + }, Poll::Ready(PoolEvent::ConnectionEvent { connection, event }) => { - NetworkEvent::ConnectionEvent { - connection, - event, - } - } - Poll::Ready(PoolEvent::AddressChange { connection, new_endpoint, old_endpoint }) => { - NetworkEvent::AddressChange { - connection, - new_endpoint, - old_endpoint, - } + NetworkEvent::ConnectionEvent { connection, event } } + Poll::Ready(PoolEvent::AddressChange { + connection, + new_endpoint, + old_endpoint, + }) => NetworkEvent::AddressChange { + connection, + new_endpoint, + old_endpoint, + }, }; Poll::Ready(event) } /// Initiates a connection attempt to a known peer. - fn dial_peer(&mut self, opts: DialingOpts) - -> Result + fn dial_peer(&mut self, opts: DialingOpts) -> Result where TTrans: Transport, TTrans::Dial: Send + 'static, @@ -452,7 +478,12 @@ where TMuxer: Send + Sync + 'static, TMuxer::OutboundSubstream: Send, { - dial_peer_impl(self.transport().clone(), &mut self.pool, &mut self.dialing, opts) + dial_peer_impl( + self.transport().clone(), + &mut self.pool, + &mut self.dialing, + opts, + ) } } @@ -470,15 +501,13 @@ fn dial_peer_impl( transport: TTrans, pool: &mut Pool, dialing: &mut FnvHashMap>, - opts: DialingOpts + opts: DialingOpts, ) -> Result where THandler: IntoConnectionHandler + Send + 'static, ::Error: error::Error + Send + 'static, ::OutboundOpenInfo: Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + THandler::Handler: ConnectionHandler> + Send + 'static, TTrans: Transport, TTrans::Dial: Send + 'static, TTrans::Error: error::Error + Send + 'static, @@ -493,23 +522,32 @@ where let result = match transport.dial(addr.clone()) { Ok(fut) => { let fut = fut.map_err(|e| PendingConnectionError::Transport(TransportError::Other(e))); - let info = OutgoingInfo { address: &addr, peer_id: Some(&opts.peer) }; - pool.add_outgoing(fut, opts.handler, info).map_err(DialError::ConnectionLimit) - }, + let info = OutgoingInfo { + address: &addr, + peer_id: Some(&opts.peer), + }; + pool.add_outgoing(fut, opts.handler, info) + .map_err(DialError::ConnectionLimit) + } Err(err) => { let fut = future::err(PendingConnectionError::Transport(err)); - let info = OutgoingInfo { address: &addr, peer_id: Some(&opts.peer) }; - pool.add_outgoing(fut, opts.handler, info).map_err(DialError::ConnectionLimit) - }, + let info = OutgoingInfo { + address: &addr, + peer_id: Some(&opts.peer), + }; + pool.add_outgoing(fut, opts.handler, info) + .map_err(DialError::ConnectionLimit) + } }; if let Ok(id) = &result { - dialing.entry(opts.peer).or_default().push( - peer::DialingState { + dialing + .entry(opts.peer) + .or_default() + .push(peer::DialingState { current: (*id, addr), remaining: opts.remaining, - }, - ); + }); } result @@ -526,22 +564,24 @@ fn on_connection_failed<'a, TTrans, THandler>( endpoint: ConnectedPoint, error: PendingConnectionError, handler: Option, -) -> (Option>, NetworkEvent<'a, TTrans, THandlerInEvent, THandlerOutEvent, THandler>) +) -> ( + Option>, + NetworkEvent<'a, TTrans, THandlerInEvent, THandlerOutEvent, THandler>, +) where TTrans: Transport, THandler: IntoConnectionHandler, { // Check if the failed connection is associated with a dialing attempt. - let dialing_failed = dialing.iter_mut() - .find_map(|(peer, attempts)| { - if let Some(pos) = attempts.iter().position(|s| s.current.0 == id) { - let attempt = attempts.remove(pos); - let last = attempts.is_empty(); - Some((*peer, attempt, last)) - } else { - None - } - }); + let dialing_failed = dialing.iter_mut().find_map(|(peer, attempts)| { + if let Some(pos) = attempts.iter().position(|s| s.current.0 == id) { + let attempt = attempts.remove(pos); + let last = attempts.is_empty(); + Some((*peer, attempt, last)) + } else { + None + } + }); if let Some((peer_id, mut attempt, last)) = dialing_failed { if last { @@ -551,47 +591,56 @@ where let num_remain = u32::try_from(attempt.remaining.len()).unwrap(); let failed_addr = attempt.current.1.clone(); - let (opts, attempts_remaining) = - if num_remain > 0 { - if let Some(handler) = handler { - let next_attempt = attempt.remaining.remove(0); - let opts = DialingOpts { - peer: peer_id, - handler, - address: next_attempt, - remaining: attempt.remaining - }; - (Some(opts), num_remain) - } else { - // The error is "fatal" for the dialing attempt, since - // the handler was already consumed. All potential - // remaining connection attempts are thus void. - (None, 0) - } + let (opts, attempts_remaining) = if num_remain > 0 { + if let Some(handler) = handler { + let next_attempt = attempt.remaining.remove(0); + let opts = DialingOpts { + peer: peer_id, + handler, + address: next_attempt, + remaining: attempt.remaining, + }; + (Some(opts), num_remain) } else { + // The error is "fatal" for the dialing attempt, since + // the handler was already consumed. All potential + // remaining connection attempts are thus void. (None, 0) - }; + } + } else { + (None, 0) + }; - (opts, NetworkEvent::DialError { - attempts_remaining, - peer_id, - multiaddr: failed_addr, - error, - }) + ( + opts, + NetworkEvent::DialError { + attempts_remaining, + peer_id, + multiaddr: failed_addr, + error, + }, + ) } else { // A pending incoming connection or outgoing connection to an unknown peer failed. match endpoint { - ConnectedPoint::Dialer { address } => - (None, NetworkEvent::UnknownPeerDialError { + ConnectedPoint::Dialer { address } => ( + None, + NetworkEvent::UnknownPeerDialError { multiaddr: address, error, - }), - ConnectedPoint::Listener { local_addr, send_back_addr } => - (None, NetworkEvent::IncomingConnectionError { + }, + ), + ConnectedPoint::Listener { + local_addr, + send_back_addr, + } => ( + None, + NetworkEvent::IncomingConnectionError { local_addr, send_back_addr, - error - }) + error, + }, + ), } } } @@ -644,7 +693,7 @@ impl NetworkConfig { /// only if no executor has already been configured. pub fn or_else_with_executor(mut self, f: F) -> Self where - F: FnOnce() -> Option> + F: FnOnce() -> Option>, { self.manager_config.executor = self.manager_config.executor.or_else(f); self @@ -693,7 +742,7 @@ impl NetworkConfig { fn p2p_addr(peer: PeerId, addr: Multiaddr) -> Result { if let Some(multiaddr::Protocol::P2p(hash)) = addr.iter().last() { if &hash != peer.as_ref() { - return Err(addr) + return Err(addr); } Ok(addr) } else { @@ -718,7 +767,7 @@ mod tests { struct Dummy; impl Executor for Dummy { - fn exec(&self, _: Pin + Send>>) { } + fn exec(&self, _: Pin + Send>>) {} } #[test] diff --git a/core/src/network/event.rs b/core/src/network/event.rs index 8154bd2087d..7b4158265d9 100644 --- a/core/src/network/event.rs +++ b/core/src/network/event.rs @@ -21,20 +21,12 @@ //! Network events and associated information. use crate::{ - Multiaddr, connection::{ - ConnectionId, - ConnectedPoint, - ConnectionError, - ConnectionHandler, - Connected, - EstablishedConnection, - IntoConnectionHandler, - ListenerId, - PendingConnectionError, + Connected, ConnectedPoint, ConnectionError, ConnectionHandler, ConnectionId, + EstablishedConnection, IntoConnectionHandler, ListenerId, PendingConnectionError, }, transport::Transport, - PeerId + Multiaddr, PeerId, }; use std::{fmt, num::NonZeroU32}; @@ -60,7 +52,7 @@ where /// The listener that errored. listener_id: ListenerId, /// The listener error. - error: TTrans::Error + error: TTrans::Error, }, /// One of the listeners is now listening on an additional address. @@ -68,7 +60,7 @@ where /// The listener that is listening on the new address. listener_id: ListenerId, /// The new address the listener is now also listening on. - listen_addr: Multiaddr + listen_addr: Multiaddr, }, /// One of the listeners is no longer listening on some address. @@ -76,7 +68,7 @@ where /// The listener that is no longer listening on some address. listener_id: ListenerId, /// The expired address. - listen_addr: Multiaddr + listen_addr: Multiaddr, }, /// A new connection arrived on a listener. @@ -177,8 +169,8 @@ where }, } -impl fmt::Debug for - NetworkEvent<'_, TTrans, TInEvent, TOutEvent, THandler> +impl fmt::Debug + for NetworkEvent<'_, TTrans, TInEvent, TOutEvent, THandler> where TInEvent: fmt::Debug, TOutEvent: fmt::Debug, @@ -189,83 +181,101 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - NetworkEvent::NewListenerAddress { listener_id, listen_addr } => { - f.debug_struct("NewListenerAddress") - .field("listener_id", listener_id) - .field("listen_addr", listen_addr) - .finish() - } - NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr } => { - f.debug_struct("ExpiredListenerAddress") - .field("listener_id", listener_id) - .field("listen_addr", listen_addr) - .finish() - } - NetworkEvent::ListenerClosed { listener_id, addresses, reason } => { - f.debug_struct("ListenerClosed") - .field("listener_id", listener_id) - .field("addresses", addresses) - .field("reason", reason) - .finish() - } - NetworkEvent::ListenerError { listener_id, error } => { - f.debug_struct("ListenerError") - .field("listener_id", listener_id) - .field("error", error) - .finish() - } - NetworkEvent::IncomingConnection { connection, .. } => { - f.debug_struct("IncomingConnection") - .field("local_addr", &connection.local_addr) - .field("send_back_addr", &connection.send_back_addr) - .finish() - } - NetworkEvent::IncomingConnectionError { local_addr, send_back_addr, error } => { - f.debug_struct("IncomingConnectionError") - .field("local_addr", local_addr) - .field("send_back_addr", send_back_addr) - .field("error", error) - .finish() - } - NetworkEvent::ConnectionEstablished { connection, .. } => { - f.debug_struct("ConnectionEstablished") - .field("connection", connection) - .finish() - } - NetworkEvent::ConnectionClosed { id, connected, error, .. } => { - f.debug_struct("ConnectionClosed") - .field("id", id) - .field("connected", connected) - .field("error", error) - .finish() - } - NetworkEvent::DialError { attempts_remaining, peer_id, multiaddr, error } => { - f.debug_struct("DialError") - .field("attempts_remaining", attempts_remaining) - .field("peer_id", peer_id) - .field("multiaddr", multiaddr) - .field("error", error) - .finish() - } - NetworkEvent::UnknownPeerDialError { multiaddr, error, .. } => { - f.debug_struct("UnknownPeerDialError") - .field("multiaddr", multiaddr) - .field("error", error) - .finish() - } - NetworkEvent::ConnectionEvent { connection, event } => { - f.debug_struct("ConnectionEvent") - .field("connection", connection) - .field("event", event) - .finish() - } - NetworkEvent::AddressChange { connection, new_endpoint, old_endpoint } => { - f.debug_struct("AddressChange") - .field("connection", connection) - .field("new_endpoint", new_endpoint) - .field("old_endpoint", old_endpoint) - .finish() - } + NetworkEvent::NewListenerAddress { + listener_id, + listen_addr, + } => f + .debug_struct("NewListenerAddress") + .field("listener_id", listener_id) + .field("listen_addr", listen_addr) + .finish(), + NetworkEvent::ExpiredListenerAddress { + listener_id, + listen_addr, + } => f + .debug_struct("ExpiredListenerAddress") + .field("listener_id", listener_id) + .field("listen_addr", listen_addr) + .finish(), + NetworkEvent::ListenerClosed { + listener_id, + addresses, + reason, + } => f + .debug_struct("ListenerClosed") + .field("listener_id", listener_id) + .field("addresses", addresses) + .field("reason", reason) + .finish(), + NetworkEvent::ListenerError { listener_id, error } => f + .debug_struct("ListenerError") + .field("listener_id", listener_id) + .field("error", error) + .finish(), + NetworkEvent::IncomingConnection { connection, .. } => f + .debug_struct("IncomingConnection") + .field("local_addr", &connection.local_addr) + .field("send_back_addr", &connection.send_back_addr) + .finish(), + NetworkEvent::IncomingConnectionError { + local_addr, + send_back_addr, + error, + } => f + .debug_struct("IncomingConnectionError") + .field("local_addr", local_addr) + .field("send_back_addr", send_back_addr) + .field("error", error) + .finish(), + NetworkEvent::ConnectionEstablished { connection, .. } => f + .debug_struct("ConnectionEstablished") + .field("connection", connection) + .finish(), + NetworkEvent::ConnectionClosed { + id, + connected, + error, + .. + } => f + .debug_struct("ConnectionClosed") + .field("id", id) + .field("connected", connected) + .field("error", error) + .finish(), + NetworkEvent::DialError { + attempts_remaining, + peer_id, + multiaddr, + error, + } => f + .debug_struct("DialError") + .field("attempts_remaining", attempts_remaining) + .field("peer_id", peer_id) + .field("multiaddr", multiaddr) + .field("error", error) + .finish(), + NetworkEvent::UnknownPeerDialError { + multiaddr, error, .. + } => f + .debug_struct("UnknownPeerDialError") + .field("multiaddr", multiaddr) + .field("error", error) + .finish(), + NetworkEvent::ConnectionEvent { connection, event } => f + .debug_struct("ConnectionEvent") + .field("connection", connection) + .field("event", event) + .finish(), + NetworkEvent::AddressChange { + connection, + new_endpoint, + old_endpoint, + } => f + .debug_struct("AddressChange") + .field("connection", connection) + .field("new_endpoint", new_endpoint) + .field("old_endpoint", old_endpoint) + .finish(), } } } diff --git a/core/src/network/peer.rs b/core/src/network/peer.rs index 88c96aa0983..ca1b9be7502 100644 --- a/core/src/network/peer.rs +++ b/core/src/network/peer.rs @@ -18,35 +18,18 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use super::{DialError, DialingOpts, Network}; use crate::{ - Multiaddr, - Transport, - StreamMuxer, connection::{ - Connected, - ConnectedPoint, - ConnectionHandler, - Connection, - ConnectionId, - ConnectionLimit, - EstablishedConnection, - EstablishedConnectionIter, - IntoConnectionHandler, - PendingConnection, - Substream, - handler::THandlerInEvent, - pool::Pool, + handler::THandlerInEvent, pool::Pool, Connected, ConnectedPoint, Connection, + ConnectionHandler, ConnectionId, ConnectionLimit, EstablishedConnection, + EstablishedConnectionIter, IntoConnectionHandler, PendingConnection, Substream, }, - PeerId + Multiaddr, PeerId, StreamMuxer, Transport, }; use fnv::FnvHashMap; use smallvec::SmallVec; -use std::{ - collections::hash_map, - error, - fmt, -}; -use super::{Network, DialingOpts, DialError}; +use std::{collections::hash_map, error, fmt}; /// The possible representations of a peer in a [`Network`], as /// seen by the local node. @@ -57,7 +40,7 @@ use super::{Network, DialingOpts, DialError}; pub enum Peer<'a, TTrans, THandler> where TTrans: Transport, - THandler: IntoConnectionHandler + THandler: IntoConnectionHandler, { /// At least one established connection exists to the peer. Connected(ConnectedPeer<'a, TTrans, THandler>), @@ -76,53 +59,33 @@ where Local, } -impl<'a, TTrans, THandler> fmt::Debug for - Peer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> fmt::Debug for Peer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - Peer::Connected(p) => { - f.debug_struct("Connected") - .field("peer", &p) - .finish() - } - Peer::Dialing(p) => { - f.debug_struct("Dialing") - .field("peer", &p) - .finish() - } - Peer::Disconnected(p) => { - f.debug_struct("Disconnected") - .field("peer", &p) - .finish() - } - Peer::Local => { - f.debug_struct("Local") - .finish() - } + Peer::Connected(p) => f.debug_struct("Connected").field("peer", &p).finish(), + Peer::Dialing(p) => f.debug_struct("Dialing").field("peer", &p).finish(), + Peer::Disconnected(p) => f.debug_struct("Disconnected").field("peer", &p).finish(), + Peer::Local => f.debug_struct("Local").finish(), } } } -impl<'a, TTrans, THandler> - Peer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> Peer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, { - pub(super) fn new( - network: &'a mut Network, - peer_id: PeerId - ) -> Self { + pub(super) fn new(network: &'a mut Network, peer_id: PeerId) -> Self { if peer_id == network.local_peer_id { return Peer::Local; } if network.pool.is_connected(&peer_id) { - return Self::connected(network, peer_id) + return Self::connected(network, peer_id); } if network.dialing.get_mut(&peer_id).is_some() { @@ -132,31 +95,20 @@ where Self::disconnected(network, peer_id) } - - fn disconnected( - network: &'a mut Network, - peer_id: PeerId - ) -> Self { + fn disconnected(network: &'a mut Network, peer_id: PeerId) -> Self { Peer::Disconnected(DisconnectedPeer { network, peer_id }) } - fn connected( - network: &'a mut Network, - peer_id: PeerId - ) -> Self { + fn connected(network: &'a mut Network, peer_id: PeerId) -> Self { Peer::Connected(ConnectedPeer { network, peer_id }) } - fn dialing( - network: &'a mut Network, - peer_id: PeerId - ) -> Self { + fn dialing(network: &'a mut Network, peer_id: PeerId) -> Self { Peer::Dialing(DialingPeer { network, peer_id }) } } -impl<'a, TTrans, TMuxer, THandler> - Peer<'a, TTrans, THandler> +impl<'a, TTrans, TMuxer, THandler> Peer<'a, TTrans, THandler> where TTrans: Transport + Clone, TTrans::Error: Send + 'static, @@ -176,7 +128,7 @@ where Peer::Connected(..) => true, Peer::Dialing(peer) => peer.is_connected(), Peer::Disconnected(..) => false, - Peer::Local => false + Peer::Local => false, } } @@ -188,7 +140,7 @@ where Peer::Dialing(_) => true, Peer::Connected(peer) => peer.is_dialing(), Peer::Disconnected(..) => false, - Peer::Local => false + Peer::Local => false, } } @@ -206,11 +158,12 @@ where /// `remaining` addresses are tried in order in subsequent connection /// attempts in the context of the same dialing attempt, if the connection /// attempt to the first address fails. - pub fn dial(self, address: Multiaddr, remaining: I, handler: THandler) - -> Result< - (ConnectionId, DialingPeer<'a, TTrans, THandler>), - DialError - > + pub fn dial( + self, + address: Multiaddr, + remaining: I, + handler: THandler, + ) -> Result<(ConnectionId, DialingPeer<'a, TTrans, THandler>), DialError> where I: IntoIterator, { @@ -218,9 +171,12 @@ where Peer::Connected(p) => (p.peer_id, p.network), Peer::Dialing(p) => (p.peer_id, p.network), Peer::Disconnected(p) => (p.peer_id, p.network), - Peer::Local => return Err(DialError::ConnectionLimit(ConnectionLimit { - current: 0, limit: 0 - })) + Peer::Local => { + return Err(DialError::ConnectionLimit(ConnectionLimit { + current: 0, + limit: 0, + })) + } }; let id = network.dial_peer(DialingOpts { @@ -236,9 +192,7 @@ where /// Converts the peer into a `ConnectedPeer`, if an established connection exists. /// /// Succeeds if the there is at least one established connection to the peer. - pub fn into_connected(self) -> Option< - ConnectedPeer<'a, TTrans, THandler> - > { + pub fn into_connected(self) -> Option> { match self { Peer::Connected(peer) => Some(peer), Peer::Dialing(peer) => peer.into_connected(), @@ -250,22 +204,18 @@ where /// Converts the peer into a `DialingPeer`, if a dialing attempt exists. /// /// Succeeds if the there is at least one pending outgoing connection to the peer. - pub fn into_dialing(self) -> Option< - DialingPeer<'a, TTrans, THandler> - > { + pub fn into_dialing(self) -> Option> { match self { Peer::Dialing(peer) => Some(peer), Peer::Connected(peer) => peer.into_dialing(), Peer::Disconnected(..) => None, - Peer::Local => None + Peer::Local => None, } } /// Converts the peer into a `DisconnectedPeer`, if neither an established connection /// nor a dialing attempt exists. - pub fn into_disconnected(self) -> Option< - DisconnectedPeer<'a, TTrans, THandler> - > { + pub fn into_disconnected(self) -> Option> { match self { Peer::Disconnected(peer) => Some(peer), _ => None, @@ -285,8 +235,7 @@ where peer_id: PeerId, } -impl<'a, TTrans, THandler> - ConnectedPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> ConnectedPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -301,9 +250,10 @@ where } /// Obtains an established connection to the peer by ID. - pub fn connection(&mut self, id: ConnectionId) - -> Option>> - { + pub fn connection( + &mut self, + id: ConnectionId, + ) -> Option>> { self.network.pool.get_established(id) } @@ -321,47 +271,43 @@ where /// Converts this peer into a [`DialingPeer`], if there is an ongoing /// dialing attempt, `None` otherwise. - pub fn into_dialing(self) -> Option< - DialingPeer<'a, TTrans, THandler> - > { + pub fn into_dialing(self) -> Option> { if self.network.dialing.contains_key(&self.peer_id) { - Some(DialingPeer { network: self.network, peer_id: self.peer_id }) + Some(DialingPeer { + network: self.network, + peer_id: self.peer_id, + }) } else { None } } /// Gets an iterator over all established connections to the peer. - pub fn connections(&mut self) -> - EstablishedConnectionIter< - impl Iterator, - THandler, - TTrans::Error, - > + pub fn connections( + &mut self, + ) -> EstablishedConnectionIter, THandler, TTrans::Error> { self.network.pool.iter_peer_established(&self.peer_id) } /// Obtains some established connection to the peer. - pub fn some_connection(&mut self) - -> EstablishedConnection> - { + pub fn some_connection(&mut self) -> EstablishedConnection> { self.connections() .into_first() .expect("By `Peer::new` and the definition of `ConnectedPeer`.") } /// Disconnects from the peer, closing all connections. - pub fn disconnect(self) - -> DisconnectedPeer<'a, TTrans, THandler> - { + pub fn disconnect(self) -> DisconnectedPeer<'a, TTrans, THandler> { self.network.disconnect(&self.peer_id); - DisconnectedPeer { network: self.network, peer_id: self.peer_id } + DisconnectedPeer { + network: self.network, + peer_id: self.peer_id, + } } } -impl<'a, TTrans, THandler> fmt::Debug for - ConnectedPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> fmt::Debug for ConnectedPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -369,7 +315,10 @@ where fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { f.debug_struct("ConnectedPeer") .field("peer_id", &self.peer_id) - .field("established", &self.network.pool.iter_peer_established_info(&self.peer_id)) + .field( + "established", + &self.network.pool.iter_peer_established_info(&self.peer_id), + ) .field("attempts", &self.network.dialing.get(&self.peer_id)) .finish() } @@ -387,8 +336,7 @@ where peer_id: PeerId, } -impl<'a, TTrans, THandler> - DialingPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> DialingPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -404,11 +352,12 @@ where /// Disconnects from this peer, closing all established connections and /// aborting all dialing attempts. - pub fn disconnect(self) - -> DisconnectedPeer<'a, TTrans, THandler> - { + pub fn disconnect(self) -> DisconnectedPeer<'a, TTrans, THandler> { self.network.disconnect(&self.peer_id); - DisconnectedPeer { network: self.network, peer_id: self.peer_id } + DisconnectedPeer { + network: self.network, + peer_id: self.peer_id, + } } /// Checks whether there is an established connection to the peer. @@ -419,11 +368,12 @@ where } /// Converts the peer into a `ConnectedPeer`, if an established connection exists. - pub fn into_connected(self) - -> Option> - { + pub fn into_connected(self) -> Option> { if self.is_connected() { - Some(ConnectedPeer { peer_id: self.peer_id, network: self.network }) + Some(ConnectedPeer { + peer_id: self.peer_id, + network: self.network, + }) } else { None } @@ -431,13 +381,18 @@ where /// Obtains a dialing attempt to the peer by connection ID of /// the current connection attempt. - pub fn attempt(&mut self, id: ConnectionId) - -> Option>> - { + pub fn attempt( + &mut self, + id: ConnectionId, + ) -> Option>> { if let hash_map::Entry::Occupied(attempts) = self.network.dialing.entry(self.peer_id) { if let Some(pos) = attempts.get().iter().position(|s| s.current.0 == id) { if let Some(inner) = self.network.pool.get_outgoing(id) { - return Some(DialingAttempt { pos, inner, attempts }) + return Some(DialingAttempt { + pos, + inner, + attempts, + }); } } } @@ -445,25 +400,25 @@ where } /// Gets an iterator over all dialing (i.e. pending outgoing) connections to the peer. - pub fn attempts(&mut self) -> DialingAttemptIter<'_, THandler, TTrans::Error> - { - DialingAttemptIter::new(&self.peer_id, &mut self.network.pool, &mut self.network.dialing) + pub fn attempts(&mut self) -> DialingAttemptIter<'_, THandler, TTrans::Error> { + DialingAttemptIter::new( + &self.peer_id, + &mut self.network.pool, + &mut self.network.dialing, + ) } /// Obtains some dialing connection to the peer. /// /// At least one dialing connection is guaranteed to exist on a `DialingPeer`. - pub fn some_attempt(&mut self) - -> DialingAttempt<'_, THandlerInEvent> - { + pub fn some_attempt(&mut self) -> DialingAttempt<'_, THandlerInEvent> { self.attempts() .into_first() .expect("By `Peer::new` and the definition of `DialingPeer`.") } } -impl<'a, TTrans, THandler> fmt::Debug for - DialingPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> fmt::Debug for DialingPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -471,7 +426,10 @@ where fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { f.debug_struct("DialingPeer") .field("peer_id", &self.peer_id) - .field("established", &self.network.pool.iter_peer_established_info(&self.peer_id)) + .field( + "established", + &self.network.pool.iter_peer_established_info(&self.peer_id), + ) .field("attempts", &self.network.dialing.get(&self.peer_id)) .finish() } @@ -489,8 +447,7 @@ where network: &'a mut Network, } -impl<'a, TTrans, THandler> fmt::Debug for - DisconnectedPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> fmt::Debug for DisconnectedPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -502,8 +459,7 @@ where } } -impl<'a, TTrans, THandler> - DisconnectedPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> DisconnectedPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -529,10 +485,8 @@ where self, connected: Connected, connection: Connection, - ) -> Result< - ConnectedPeer<'a, TTrans, THandler>, - ConnectionLimit - > where + ) -> Result, ConnectionLimit> + where THandler: Send + 'static, TTrans::Error: Send + 'static, THandler::Handler: ConnectionHandler> + Send, @@ -542,10 +496,15 @@ where TMuxer::OutboundSubstream: Send, { if connected.peer_id != self.peer_id { - panic!("Invalid peer ID given: {:?}. Expected: {:?}", connected.peer_id, self.peer_id) + panic!( + "Invalid peer ID given: {:?}. Expected: {:?}", + connected.peer_id, self.peer_id + ) } - self.network.pool.add(connection, connected) + self.network + .pool + .add(connection, connected) .map(move |_id| ConnectedPeer { network: self.network, peer_id: self.peer_id, @@ -575,9 +534,7 @@ pub struct DialingAttempt<'a, TInEvent> { pos: usize, } -impl<'a, TInEvent> - DialingAttempt<'a, TInEvent> -{ +impl<'a, TInEvent> DialingAttempt<'a, TInEvent> { /// Returns the ID of the current connection attempt. pub fn id(&self) -> ConnectionId { self.inner.id() @@ -592,7 +549,7 @@ impl<'a, TInEvent> pub fn address(&self) -> &Multiaddr { match self.inner.endpoint() { ConnectedPoint::Dialer { address } => address, - ConnectedPoint::Listener { .. } => unreachable!("by definition of a `DialingAttempt`.") + ConnectedPoint::Listener { .. } => unreachable!("by definition of a `DialingAttempt`."), } } @@ -640,16 +597,20 @@ pub struct DialingAttemptIter<'a, THandler: IntoConnectionHandler, TTransErr> { // Note: Ideally this would be an implementation of `Iterator`, but that // requires GATs (cf. https://github.com/rust-lang/rust/issues/44265) and // a different definition of `Iterator`. -impl<'a, THandler: IntoConnectionHandler, TTransErr> - DialingAttemptIter<'a, THandler, TTransErr> -{ +impl<'a, THandler: IntoConnectionHandler, TTransErr> DialingAttemptIter<'a, THandler, TTransErr> { fn new( peer_id: &'a PeerId, pool: &'a mut Pool, dialing: &'a mut FnvHashMap>, ) -> Self { let end = dialing.get(peer_id).map_or(0, |conns| conns.len()); - Self { pos: 0, end, pool, dialing, peer_id } + Self { + pos: 0, + end, + pool, + dialing, + peer_id, + } } /// Obtains the next dialing connection, if any. @@ -658,22 +619,29 @@ impl<'a, THandler: IntoConnectionHandler, TTransErr> // If the number of elements reduced, the current `DialingAttempt` has been // aborted and iteration needs to continue from the previous position to // account for the removed element. - let end = self.dialing.get(self.peer_id).map_or(0, |conns| conns.len()); + let end = self + .dialing + .get(self.peer_id) + .map_or(0, |conns| conns.len()); if self.end > end { self.end = end; self.pos -= 1; } if self.pos == self.end { - return None + return None; } if let hash_map::Entry::Occupied(attempts) = self.dialing.entry(*self.peer_id) { let id = attempts.get()[self.pos].current.0; if let Some(inner) = self.pool.get_outgoing(id) { - let conn = DialingAttempt { pos: self.pos, inner, attempts }; + let conn = DialingAttempt { + pos: self.pos, + inner, + attempts, + }; self.pos += 1; - return Some(conn) + return Some(conn); } } @@ -681,18 +649,22 @@ impl<'a, THandler: IntoConnectionHandler, TTransErr> } /// Returns the first connection, if any, consuming the iterator. - pub fn into_first<'b>(self) - -> Option>> - where 'a: 'b + pub fn into_first<'b>(self) -> Option>> + where + 'a: 'b, { if self.pos == self.end { - return None + return None; } if let hash_map::Entry::Occupied(attempts) = self.dialing.entry(*self.peer_id) { let id = attempts.get()[self.pos].current.0; if let Some(inner) = self.pool.get_outgoing(id) { - return Some(DialingAttempt { pos: self.pos, inner, attempts }) + return Some(DialingAttempt { + pos: self.pos, + inner, + attempts, + }); } } diff --git a/core/src/peer_id.rs b/core/src/peer_id.rs index 37d46038243..5a9ae8b0341 100644 --- a/core/src/peer_id.rs +++ b/core/src/peer_id.rs @@ -38,9 +38,7 @@ pub struct PeerId { impl fmt::Debug for PeerId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("PeerId") - .field(&self.to_base58()) - .finish() + f.debug_tuple("PeerId").field(&self.to_base58()).finish() } } @@ -80,9 +78,10 @@ impl PeerId { pub fn from_multihash(multihash: Multihash) -> Result { match Code::try_from(multihash.code()) { Ok(Code::Sha2_256) => Ok(PeerId { multihash }), - Ok(Code::Identity) if multihash.digest().len() <= MAX_INLINE_KEY_LENGTH - => Ok(PeerId { multihash }), - _ => Err(multihash) + Ok(Code::Identity) if multihash.digest().len() <= MAX_INLINE_KEY_LENGTH => { + Ok(PeerId { multihash }) + } + _ => Err(multihash), } } @@ -93,7 +92,7 @@ impl PeerId { let peer_id = rand::thread_rng().gen::<[u8; 32]>(); PeerId { multihash: Multihash::wrap(Code::Identity.into(), &peer_id) - .expect("The digest size is never too large") + .expect("The digest size is never too large"), } } @@ -185,7 +184,7 @@ impl FromStr for PeerId { #[cfg(test)] mod tests { - use crate::{PeerId, identity}; + use crate::{identity, PeerId}; #[test] fn peer_id_is_public_key() { @@ -210,7 +209,7 @@ mod tests { #[test] fn random_peer_id_is_valid() { - for _ in 0 .. 5000 { + for _ in 0..5000 { let peer_id = PeerId::random(); assert_eq!(peer_id, PeerId::from_bytes(&peer_id.to_bytes()).unwrap()); } diff --git a/core/src/transport.rs b/core/src/transport.rs index f6e70c44628..7006c15a810 100644 --- a/core/src/transport.rs +++ b/core/src/transport.rs @@ -94,7 +94,9 @@ pub trait Transport { /// /// If this stream produces an error, it is considered fatal and the listener is killed. It /// is possible to report non-fatal errors by producing a [`ListenerEvent::Error`]. - type Listener: Stream, Self::Error>>; + type Listener: Stream< + Item = Result, Self::Error>, + >; /// A pending [`Output`](Transport::Output) for an inbound connection, /// obtained from the [`Listener`](Transport::Listener) stream. @@ -149,7 +151,7 @@ pub trait Transport { fn map(self, f: F) -> map::Map where Self: Sized, - F: FnOnce(Self::Output, ConnectedPoint) -> O + Clone + F: FnOnce(Self::Output, ConnectedPoint) -> O + Clone, { map::Map::new(self, f) } @@ -158,7 +160,7 @@ pub trait Transport { fn map_err(self, f: F) -> map_err::MapErr where Self: Sized, - F: FnOnce(Self::Error) -> E + Clone + F: FnOnce(Self::Error) -> E + Clone, { map_err::MapErr::new(self, f) } @@ -172,7 +174,7 @@ pub trait Transport { where Self: Sized, U: Transport, - ::Error: 'static + ::Error: 'static, { OrTransport::new(self, other) } @@ -189,7 +191,7 @@ pub trait Transport { Self: Sized, C: FnOnce(Self::Output, ConnectedPoint) -> F + Clone, F: TryFuture, - ::Error: Error + 'static + ::Error: Error + 'static, { and_then::AndThen::new(self, f) } @@ -199,7 +201,7 @@ pub trait Transport { fn upgrade(self, version: upgrade::Version) -> upgrade::Builder where Self: Sized, - Self::Error: 'static + Self::Error: 'static, { upgrade::Builder::new(self, version) } @@ -222,7 +224,7 @@ pub enum ListenerEvent { /// The local address which produced this upgrade. local_addr: Multiaddr, /// The remote address which produced this upgrade. - remote_addr: Multiaddr + remote_addr: Multiaddr, }, /// A [`Multiaddr`] is no longer used for listening. AddressExpired(Multiaddr), @@ -239,9 +241,15 @@ impl ListenerEvent { /// based the the function's result. pub fn map(self, f: impl FnOnce(TUpgr) -> U) -> ListenerEvent { match self { - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => { - ListenerEvent::Upgrade { upgrade: f(upgrade), local_addr, remote_addr } - } + ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + } => ListenerEvent::Upgrade { + upgrade: f(upgrade), + local_addr, + remote_addr, + }, ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), ListenerEvent::Error(e) => ListenerEvent::Error(e), @@ -253,8 +261,15 @@ impl ListenerEvent { /// function's result. pub fn map_err(self, f: impl FnOnce(TErr) -> U) -> ListenerEvent { match self { - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr }, + ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + } => ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + }, ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), ListenerEvent::Error(e) => ListenerEvent::Error(f(e)), @@ -263,7 +278,7 @@ impl ListenerEvent { /// Returns `true` if this is an `Upgrade` listener event. pub fn is_upgrade(&self) -> bool { - matches!(self, ListenerEvent::Upgrade {..}) + matches!(self, ListenerEvent::Upgrade { .. }) } /// Try to turn this listener event into upgrade parts. @@ -271,7 +286,12 @@ impl ListenerEvent { /// Returns `None` if the event is not actually an upgrade, /// otherwise the upgrade and the remote address. pub fn into_upgrade(self) -> Option<(TUpgr, Multiaddr)> { - if let ListenerEvent::Upgrade { upgrade, remote_addr, .. } = self { + if let ListenerEvent::Upgrade { + upgrade, + remote_addr, + .. + } = self + { Some((upgrade, remote_addr)) } else { None @@ -347,25 +367,31 @@ impl TransportError { /// Applies a function to the the error in [`TransportError::Other`]. pub fn map(self, map: impl FnOnce(TErr) -> TNewErr) -> TransportError { match self { - TransportError::MultiaddrNotSupported(addr) => TransportError::MultiaddrNotSupported(addr), + TransportError::MultiaddrNotSupported(addr) => { + TransportError::MultiaddrNotSupported(addr) + } TransportError::Other(err) => TransportError::Other(map(err)), } } } impl fmt::Display for TransportError -where TErr: fmt::Display, +where + TErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TransportError::MultiaddrNotSupported(addr) => write!(f, "Multiaddr is not supported: {}", addr), + TransportError::MultiaddrNotSupported(addr) => { + write!(f, "Multiaddr is not supported: {}", addr) + } TransportError::Other(err) => write!(f, "{}", err), } } } impl Error for TransportError -where TErr: Error + 'static, +where + TErr: Error + 'static, { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { diff --git a/core/src/transport/and_then.rs b/core/src/transport/and_then.rs index 22018729a07..51f5d88c2b6 100644 --- a/core/src/transport/and_then.rs +++ b/core/src/transport/and_then.rs @@ -19,9 +19,9 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - ConnectedPoint, either::EitherError, - transport::{Transport, TransportError, ListenerEvent} + transport::{ListenerEvent, Transport, TransportError}, + ConnectedPoint, }; use futures::{future::Either, prelude::*}; use multiaddr::Multiaddr; @@ -29,7 +29,10 @@ use std::{error, marker::PhantomPinned, pin::Pin, task::Context, task::Poll}; /// See the `Transport::and_then` method. #[derive(Debug, Clone)] -pub struct AndThen { transport: T, fun: C } +pub struct AndThen { + transport: T, + fun: C, +} impl AndThen { pub(crate) fn new(transport: T, fun: C) -> Self { @@ -51,17 +54,26 @@ where type Dial = AndThenFuture; fn listen_on(self, addr: Multiaddr) -> Result> { - let listener = self.transport.listen_on(addr).map_err(|err| err.map(EitherError::A))?; + let listener = self + .transport + .listen_on(addr) + .map_err(|err| err.map(EitherError::A))?; // Try to negotiate the protocol. // Note that failing to negotiate a protocol will never produce a future with an error. // Instead the `stream` will produce `Ok(Err(...))`. // `stream` can only produce an `Err` if `listening_stream` produces an `Err`. - let stream = AndThenStream { stream: listener, fun: self.fun }; + let stream = AndThenStream { + stream: listener, + fun: self.fun, + }; Ok(stream) } fn dial(self, addr: Multiaddr) -> Result> { - let dialed_fut = self.transport.dial(addr.clone()).map_err(|err| err.map(EitherError::A))?; + let dialed_fut = self + .transport + .dial(addr.clone()) + .map_err(|err| err.map(EitherError::A))?; let future = AndThenFuture { inner: Either::Left(Box::pin(dialed_fut)), args: Some((self.fun, ConnectedPoint::Dialer { address: addr })), @@ -83,19 +95,23 @@ where pub struct AndThenStream { #[pin] stream: TListener, - fun: TMap + fun: TMap, } -impl Stream for AndThenStream +impl Stream + for AndThenStream where TListener: TryStream, Error = TTransErr>, TListUpgr: TryFuture, TMap: FnOnce(TTransOut, ConnectedPoint) -> TMapOut + Clone, - TMapOut: TryFuture + TMapOut: TryFuture, { type Item = Result< - ListenerEvent, EitherError>, - EitherError + ListenerEvent< + AndThenFuture, + EitherError, + >, + EitherError, >; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -103,10 +119,14 @@ where match TryStream::try_poll_next(this.stream, cx) { Poll::Ready(Some(Ok(event))) => { let event = match event { - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => { + ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + } => { let point = ConnectedPoint::Listener { local_addr: local_addr.clone(), - send_back_addr: remote_addr.clone() + send_back_addr: remote_addr.clone(), }; ListenerEvent::Upgrade { upgrade: AndThenFuture { @@ -115,7 +135,7 @@ where marker: PhantomPinned, }, local_addr, - remote_addr + remote_addr, } } ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), @@ -127,7 +147,7 @@ where } Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending + Poll::Pending => Poll::Pending, } } } @@ -159,7 +179,10 @@ where Poll::Ready(Err(err)) => return Poll::Ready(Err(EitherError::A(err))), Poll::Pending => return Poll::Pending, }; - let (f, a) = self.args.take().expect("AndThenFuture has already finished."); + let (f, a) = self + .args + .take() + .expect("AndThenFuture has already finished."); f(item, a) } Either::Right(future) => { @@ -176,5 +199,4 @@ where } } -impl Unpin for AndThenFuture { -} +impl Unpin for AndThenFuture {} diff --git a/core/src/transport/boxed.rs b/core/src/transport/boxed.rs index 5322b517dbe..001a0c9fdf3 100644 --- a/core/src/transport/boxed.rs +++ b/core/src/transport/boxed.rs @@ -45,7 +45,8 @@ pub struct Boxed { } type Dial = Pin> + Send>>; -type Listener = Pin, io::Error>>> + Send>>; +type Listener = + Pin, io::Error>>> + Send>>; type ListenerUpgrade = Pin> + Send>>; trait Abstract { @@ -64,12 +65,16 @@ where { fn listen_on(&self, addr: Multiaddr) -> Result, TransportError> { let listener = Transport::listen_on(self.clone(), addr).map_err(|e| e.map(box_err))?; - let fut = listener.map_ok(|event| - event.map(|upgrade| { - let up = upgrade.map_err(box_err); - Box::pin(up) as ListenerUpgrade - }).map_err(box_err) - ).map_err(box_err); + let fut = listener + .map_ok(|event| { + event + .map(|upgrade| { + let up = upgrade.map_err(box_err); + Box::pin(up) as ListenerUpgrade + }) + .map_err(box_err) + }) + .map_err(box_err); Ok(Box::pin(fut)) } diff --git a/core/src/transport/choice.rs b/core/src/transport/choice.rs index 3488b06884d..e9545617f09 100644 --- a/core/src/transport/choice.rs +++ b/core/src/transport/choice.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::either::{EitherListenStream, EitherOutput, EitherError, EitherFuture}; +use crate::either::{EitherError, EitherFuture, EitherListenStream, EitherOutput}; use crate::transport::{Transport, TransportError}; use multiaddr::Multiaddr; @@ -47,13 +47,17 @@ where let addr = match self.0.listen_on(addr) { Ok(listener) => return Ok(EitherListenStream::First(listener)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => return Err(TransportError::Other(EitherError::A(err))), + Err(TransportError::Other(err)) => { + return Err(TransportError::Other(EitherError::A(err))) + } }; let addr = match self.1.listen_on(addr) { Ok(listener) => return Ok(EitherListenStream::Second(listener)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => return Err(TransportError::Other(EitherError::B(err))), + Err(TransportError::Other(err)) => { + return Err(TransportError::Other(EitherError::B(err))) + } }; Err(TransportError::MultiaddrNotSupported(addr)) @@ -63,13 +67,17 @@ where let addr = match self.0.dial(addr) { Ok(connec) => return Ok(EitherFuture::First(connec)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => return Err(TransportError::Other(EitherError::A(err))), + Err(TransportError::Other(err)) => { + return Err(TransportError::Other(EitherError::A(err))) + } }; let addr = match self.1.dial(addr) { Ok(connec) => return Ok(EitherFuture::Second(connec)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => return Err(TransportError::Other(EitherError::B(err))), + Err(TransportError::Other(err)) => { + return Err(TransportError::Other(EitherError::B(err))) + } }; Err(TransportError::MultiaddrNotSupported(addr)) diff --git a/core/src/transport/dummy.rs b/core/src/transport/dummy.rs index 5839a6a5928..a4eaa14901d 100644 --- a/core/src/transport/dummy.rs +++ b/core/src/transport/dummy.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{Transport, TransportError, ListenerEvent}; +use crate::transport::{ListenerEvent, Transport, TransportError}; use crate::Multiaddr; use futures::{prelude::*, task::Context, task::Poll}; use std::{fmt, io, marker::PhantomData, pin::Pin}; @@ -56,7 +56,9 @@ impl Clone for DummyTransport { impl Transport for DummyTransport { type Output = TOut; type Error = io::Error; - type Listener = futures::stream::Pending, Self::Error>>; + type Listener = futures::stream::Pending< + Result, Self::Error>, + >; type ListenerUpgrade = futures::future::Pending>; type Dial = futures::future::Pending>; @@ -83,29 +85,29 @@ impl fmt::Debug for DummyStream { } impl AsyncRead for DummyStream { - fn poll_read(self: Pin<&mut Self>, _: &mut Context<'_>, _: &mut [u8]) - -> Poll> - { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut [u8], + ) -> Poll> { Poll::Ready(Err(io::ErrorKind::Other.into())) } } impl AsyncWrite for DummyStream { - fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) - -> Poll> - { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &[u8], + ) -> Poll> { Poll::Ready(Err(io::ErrorKind::Other.into())) } - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) - -> Poll> - { + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Err(io::ErrorKind::Other.into())) } - fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) - -> Poll> - { + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Err(io::ErrorKind::Other.into())) } } diff --git a/core/src/transport/map.rs b/core/src/transport/map.rs index 0305af6626d..4493507c1d9 100644 --- a/core/src/transport/map.rs +++ b/core/src/transport/map.rs @@ -19,8 +19,8 @@ // DEALINGS IN THE SOFTWARE. use crate::{ + transport::{ListenerEvent, Transport, TransportError}, ConnectedPoint, - transport::{Transport, TransportError, ListenerEvent} }; use futures::prelude::*; use multiaddr::Multiaddr; @@ -28,7 +28,10 @@ use std::{pin::Pin, task::Context, task::Poll}; /// See `Transport::map`. #[derive(Debug, Copy, Clone)] -pub struct Map { transport: T, fun: F } +pub struct Map { + transport: T, + fun: F, +} impl Map { pub(crate) fn new(transport: T, fun: F) -> Self { @@ -39,7 +42,7 @@ impl Map { impl Transport for Map where T: Transport, - F: FnOnce(T::Output, ConnectedPoint) -> D + Clone + F: FnOnce(T::Output, ConnectedPoint) -> D + Clone, { type Output = D; type Error = T::Error; @@ -49,13 +52,19 @@ where fn listen_on(self, addr: Multiaddr) -> Result> { let stream = self.transport.listen_on(addr)?; - Ok(MapStream { stream, fun: self.fun }) + Ok(MapStream { + stream, + fun: self.fun, + }) } fn dial(self, addr: Multiaddr) -> Result> { let future = self.transport.dial(addr.clone())?; let p = ConnectedPoint::Dialer { address: addr }; - Ok(MapFuture { inner: future, args: Some((self.fun, p)) }) + Ok(MapFuture { + inner: future, + args: Some((self.fun, p)), + }) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { @@ -68,13 +77,17 @@ where /// Maps a function over every stream item. #[pin_project::pin_project] #[derive(Clone, Debug)] -pub struct MapStream { #[pin] stream: T, fun: F } +pub struct MapStream { + #[pin] + stream: T, + fun: F, +} impl Stream for MapStream where T: TryStream, Error = E>, X: TryFuture, - F: FnOnce(A, ConnectedPoint) -> B + Clone + F: FnOnce(A, ConnectedPoint) -> B + Clone, { type Item = Result, E>, E>; @@ -83,18 +96,22 @@ where match TryStream::try_poll_next(this.stream, cx) { Poll::Ready(Some(Ok(event))) => { let event = match event { - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => { + ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + } => { let point = ConnectedPoint::Listener { local_addr: local_addr.clone(), - send_back_addr: remote_addr.clone() + send_back_addr: remote_addr.clone(), }; ListenerEvent::Upgrade { upgrade: MapFuture { inner: upgrade, - args: Some((this.fun.clone(), point)) + args: Some((this.fun.clone(), point)), }, local_addr, - remote_addr + remote_addr, } } ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), @@ -105,7 +122,7 @@ where } Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending + Poll::Pending => Poll::Pending, } } } @@ -118,13 +135,13 @@ where pub struct MapFuture { #[pin] inner: T, - args: Option<(F, ConnectedPoint)> + args: Option<(F, ConnectedPoint)>, } impl Future for MapFuture where T: TryFuture, - F: FnOnce(A, ConnectedPoint) -> B + F: FnOnce(A, ConnectedPoint) -> B, { type Output = Result; diff --git a/core/src/transport/map_err.rs b/core/src/transport/map_err.rs index c0be6485204..df26214435a 100644 --- a/core/src/transport/map_err.rs +++ b/core/src/transport/map_err.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{Transport, TransportError, ListenerEvent}; +use crate::transport::{ListenerEvent, Transport, TransportError}; use futures::prelude::*; use multiaddr::Multiaddr; use std::{error, pin::Pin, task::Context, task::Poll}; @@ -53,14 +53,17 @@ where let map = self.map; match self.transport.listen_on(addr) { Ok(stream) => Ok(MapErrListener { inner: stream, map }), - Err(err) => Err(err.map(map)) + Err(err) => Err(err.map(map)), } } fn dial(self, addr: Multiaddr) -> Result> { let map = self.map; match self.transport.dial(addr) { - Ok(future) => Ok(MapErrDial { inner: future, map: Some(map) }), + Ok(future) => Ok(MapErrDial { + inner: future, + map: Some(map), + }), Err(err) => Err(err.map(map)), } } @@ -92,11 +95,9 @@ where Poll::Ready(Some(Ok(event))) => { let map = &*this.map; let event = event - .map(move |value| { - MapErrListenerUpgrade { - inner: value, - map: Some(map.clone()) - } + .map(move |value| MapErrListenerUpgrade { + inner: value, + map: Some(map.clone()), }) .map_err(|err| (map.clone())(err)); Poll::Ready(Some(Ok(event))) @@ -117,7 +118,8 @@ pub struct MapErrListenerUpgrade { } impl Future for MapErrListenerUpgrade -where T: Transport, +where + T: Transport, F: FnOnce(T::Error) -> TErr, { type Output = Result; diff --git a/core/src/transport/memory.rs b/core/src/transport/memory.rs index 043dcee06b7..3b4706c9adb 100644 --- a/core/src/transport/memory.rs +++ b/core/src/transport/memory.rs @@ -18,11 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{Transport, transport::{TransportError, ListenerEvent}}; +use crate::{ + transport::{ListenerEvent, TransportError}, + Transport, +}; use fnv::FnvHashMap; -use futures::{future::{self, Ready}, prelude::*, channel::mpsc, task::Context, task::Poll}; +use futures::{ + channel::mpsc, + future::{self, Ready}, + prelude::*, + task::Context, + task::Poll, +}; use lazy_static::lazy_static; -use multiaddr::{Protocol, Multiaddr}; +use multiaddr::{Multiaddr, Protocol}; use parking_lot::Mutex; use rw_stream_sink::RwStreamSink; use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin}; @@ -66,7 +75,7 @@ impl Hub { let (tx, rx) = mpsc::channel(2); match hub.entry(port) { Entry::Occupied(_) => return None, - Entry::Vacant(e) => e.insert(tx) + Entry::Vacant(e) => e.insert(tx), }; Some((rx, port)) @@ -103,7 +112,8 @@ impl DialFuture { fn new(port: NonZeroU64) -> Option { let sender = HUB.get(&port)?; - let (_dial_port_channel, dial_port) = HUB.register_port(0) + let (_dial_port_channel, dial_port) = HUB + .register_port(0) .expect("there to be some random unoccupied port."); let (a_tx, a_rx) = mpsc::channel(4096); @@ -129,14 +139,15 @@ impl Future for DialFuture { type Output = Result>, MemoryTransportError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.sender.poll_ready(cx) { Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)), } - let channel_to_send = self.channel_to_send.take() + let channel_to_send = self + .channel_to_send + .take() .expect("Future should not be polled again once complete"); let dial_port = self.dial_port; match self.sender.start_send((channel_to_send, dial_port)) { @@ -144,8 +155,10 @@ impl Future for DialFuture { Ok(()) => {} } - Poll::Ready(Ok(self.channel_to_return.take() - .expect("Future should not be polled again once complete"))) + Poll::Ready(Ok(self + .channel_to_return + .take() + .expect("Future should not be polled again once complete"))) } } @@ -172,7 +185,7 @@ impl Transport for MemoryTransport { port, addr: Protocol::Memory(port.get()).into(), receiver: rx, - tell_listen_addr: true + tell_listen_addr: true, }; Ok(listener) @@ -226,16 +239,19 @@ pub struct Listener { /// Receives incoming connections. receiver: ChannelReceiver, /// Generate `ListenerEvent::NewAddress` to inform about our listen address. - tell_listen_addr: bool + tell_listen_addr: bool, } impl Stream for Listener { - type Item = Result>, MemoryTransportError>>, MemoryTransportError>, MemoryTransportError>; + type Item = Result< + ListenerEvent>, MemoryTransportError>>, MemoryTransportError>, + MemoryTransportError, + >; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.tell_listen_addr { self.tell_listen_addr = false; - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))) + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))); } let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) { @@ -247,7 +263,7 @@ impl Stream for Listener { let event = ListenerEvent::Upgrade { upgrade: future::ready(Ok(channel)), local_addr: self.addr.clone(), - remote_addr: Protocol::Memory(dial_port.get()).into() + remote_addr: Protocol::Memory(dial_port.get()).into(), }; Poll::Ready(Some(Ok(event))) @@ -267,9 +283,9 @@ fn parse_memory_addr(a: &Multiaddr) -> Result { match protocols.next() { Some(Protocol::Memory(port)) => match protocols.next() { None | Some(Protocol::P2p(_)) => Ok(port), - _ => Err(()) - } - _ => Err(()) + _ => Err(()), + }, + _ => Err(()), } } @@ -294,8 +310,7 @@ pub struct Chan> { dial_port: Option, } -impl Unpin for Chan { -} +impl Unpin for Chan {} impl Stream for Chan { type Item = Result; @@ -313,12 +328,15 @@ impl Sink for Chan { type Error = io::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.outgoing.poll_ready(cx) + self.outgoing + .poll_ready(cx) .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into())) } fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { - self.outgoing.start_send(item).map_err(|_| io::ErrorKind::BrokenPipe.into()) + self.outgoing + .start_send(item) + .map_err(|_| io::ErrorKind::BrokenPipe.into()) } fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { @@ -354,30 +372,59 @@ mod tests { assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5)); assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(())); assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0)); - assert_eq!(parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()), Err(())); - assert_eq!(parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()), Err(())); - assert_eq!(parse_memory_addr(&"/memory/1234567890".parse().unwrap()), Ok(1_234_567_890)); + assert_eq!( + parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()), + Err(()) + ); + assert_eq!( + parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()), + Err(()) + ); + assert_eq!( + parse_memory_addr(&"/memory/1234567890".parse().unwrap()), + Ok(1_234_567_890) + ); } #[test] fn listening_twice() { let transport = MemoryTransport::default(); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok()); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok()); - let _listener = transport.listen_on("/memory/1639174018481".parse().unwrap()).unwrap(); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_err()); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_err()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_ok()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_ok()); + let _listener = transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .unwrap(); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_err()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_err()); drop(_listener); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok()); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_ok()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_ok()); } #[test] fn port_not_in_use() { let transport = MemoryTransport::default(); - assert!(transport.dial("/memory/810172461024613".parse().unwrap()).is_err()); - let _listener = transport.listen_on("/memory/810172461024613".parse().unwrap()).unwrap(); - assert!(transport.dial("/memory/810172461024613".parse().unwrap()).is_ok()); + assert!(transport + .dial("/memory/810172461024613".parse().unwrap()) + .is_err()); + let _listener = transport + .listen_on("/memory/810172461024613".parse().unwrap()) + .unwrap(); + assert!(transport + .dial("/memory/810172461024613".parse().unwrap()) + .is_ok()); } #[test] @@ -395,9 +442,11 @@ mod tests { let listener = async move { let listener = t1.listen_on(t1_addr.clone()).unwrap(); - let upgrade = listener.filter_map(|ev| futures::future::ready( - ListenerEvent::into_upgrade(ev.unwrap()) - )).next().await.unwrap(); + let upgrade = listener + .filter_map(|ev| futures::future::ready(ListenerEvent::into_upgrade(ev.unwrap()))) + .next() + .await + .unwrap(); let mut socket = upgrade.0.await.unwrap(); @@ -422,16 +471,14 @@ mod tests { #[test] fn dialer_address_unequal_to_listener_address() { - let listener_addr: Multiaddr = Protocol::Memory( - rand::random::().saturating_add(1), - ).into(); + let listener_addr: Multiaddr = + Protocol::Memory(rand::random::().saturating_add(1)).into(); let listener_addr_cloned = listener_addr.clone(); let listener_transport = MemoryTransport::default(); let listener = async move { - let mut listener = listener_transport.listen_on(listener_addr.clone()) - .unwrap(); + let mut listener = listener_transport.listen_on(listener_addr.clone()).unwrap(); while let Some(ev) = listener.next().await { if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() { assert!( @@ -444,7 +491,8 @@ mod tests { }; let dialer = async move { - MemoryTransport::default().dial(listener_addr_cloned) + MemoryTransport::default() + .dial(listener_addr_cloned) .unwrap() .await .unwrap(); @@ -458,21 +506,18 @@ mod tests { let (terminate, should_terminate) = futures::channel::oneshot::channel(); let (terminated, is_terminated) = futures::channel::oneshot::channel(); - let listener_addr: Multiaddr = Protocol::Memory( - rand::random::().saturating_add(1), - ).into(); + let listener_addr: Multiaddr = + Protocol::Memory(rand::random::().saturating_add(1)).into(); let listener_addr_cloned = listener_addr.clone(); let listener_transport = MemoryTransport::default(); let listener = async move { - let mut listener = listener_transport.listen_on(listener_addr.clone()) - .unwrap(); + let mut listener = listener_transport.listen_on(listener_addr.clone()).unwrap(); while let Some(ev) = listener.next().await { if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() { - let dialer_port = NonZeroU64::new( - parse_memory_addr(&remote_addr).unwrap(), - ).unwrap(); + let dialer_port = + NonZeroU64::new(parse_memory_addr(&remote_addr).unwrap()).unwrap(); assert!( HUB.get(&dialer_port).is_some(), @@ -493,7 +538,8 @@ mod tests { }; let dialer = async move { - let _chan = MemoryTransport::default().dial(listener_addr_cloned) + let _chan = MemoryTransport::default() + .dial(listener_addr_cloned) .unwrap() .await .unwrap(); diff --git a/core/src/transport/timeout.rs b/core/src/transport/timeout.rs index d55d007df08..8084dcb7521 100644 --- a/core/src/transport/timeout.rs +++ b/core/src/transport/timeout.rs @@ -24,7 +24,10 @@ //! underlying `Transport`. // TODO: add example -use crate::{Multiaddr, Transport, transport::{TransportError, ListenerEvent}}; +use crate::{ + transport::{ListenerEvent, TransportError}, + Multiaddr, Transport, +}; use futures::prelude::*; use futures_timer::Delay; use std::{error, fmt, io, pin::Pin, task::Context, task::Poll, time::Duration}; @@ -82,7 +85,9 @@ where type Dial = Timeout; fn listen_on(self, addr: Multiaddr) -> Result> { - let listener = self.inner.listen_on(addr) + let listener = self + .inner + .listen_on(addr) .map_err(|err| err.map(TransportTimeoutError::Other))?; let listener = TimeoutListener { @@ -94,7 +99,9 @@ where } fn dial(self, addr: Multiaddr) -> Result> { - let dial = self.inner.dial(addr) + let dial = self + .inner + .dial(addr) .map_err(|err| err.map(TransportTimeoutError::Other))?; Ok(Timeout { inner: dial, @@ -120,13 +127,16 @@ impl Stream for TimeoutListener where InnerStream: TryStream, Error = E>, { - type Item = Result, TransportTimeoutError>, TransportTimeoutError>; + type Item = + Result, TransportTimeoutError>, TransportTimeoutError>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); let poll_out = match TryStream::try_poll_next(this.inner, cx) { - Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(TransportTimeoutError::Other(err)))), + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(Some(Err(TransportTimeoutError::Other(err)))) + } Poll::Ready(Some(Ok(v))) => v, Poll::Ready(None) => return Poll::Ready(None), Poll::Pending => return Poll::Pending, @@ -134,11 +144,9 @@ where let timeout = *this.timeout; let event = poll_out - .map(move |inner_fut| { - Timeout { - inner: inner_fut, - timer: Delay::new(timeout), - } + .map(move |inner_fut| Timeout { + inner: inner_fut, + timer: Delay::new(timeout), }) .map_err(TransportTimeoutError::Other); @@ -173,14 +181,14 @@ where let mut this = self.project(); match TryFuture::try_poll(this.inner, cx) { - Poll::Pending => {}, + Poll::Pending => {} Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)), Poll::Ready(Err(err)) => return Poll::Ready(Err(TransportTimeoutError::Other(err))), } match Pin::new(&mut this.timer).poll(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(()) => Poll::Ready(Err(TransportTimeoutError::Timeout)) + Poll::Ready(()) => Poll::Ready(Err(TransportTimeoutError::Timeout)), } } } @@ -197,7 +205,8 @@ pub enum TransportTimeoutError { } impl fmt::Display for TransportTimeoutError -where TErr: fmt::Display, +where + TErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -209,7 +218,8 @@ where TErr: fmt::Display, } impl error::Error for TransportTimeoutError -where TErr: error::Error + 'static, +where + TErr: error::Error + 'static, { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index b2cb7b46804..7777be9256e 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -23,28 +23,16 @@ pub use crate::upgrade::Version; use crate::{ - ConnectedPoint, - Negotiated, + muxing::{StreamMuxer, StreamMuxerBox}, transport::{ - Transport, + and_then::AndThen, boxed::boxed, timeout::TransportTimeout, ListenerEvent, Transport, TransportError, - ListenerEvent, - and_then::AndThen, - boxed::boxed, - timeout::TransportTimeout, }, - muxing::{StreamMuxer, StreamMuxerBox}, upgrade::{ - self, - OutboundUpgrade, - InboundUpgrade, - apply_inbound, - apply_outbound, - UpgradeError, - OutboundUpgradeApply, - InboundUpgradeApply + self, apply_inbound, apply_outbound, InboundUpgrade, InboundUpgradeApply, OutboundUpgrade, + OutboundUpgradeApply, UpgradeError, }, - PeerId + ConnectedPoint, Negotiated, PeerId, }; use futures::{prelude::*, ready}; use multiaddr::Multiaddr; @@ -53,7 +41,7 @@ use std::{ fmt, pin::Pin, task::{Context, Poll}, - time::Duration + time::Duration, }; /// A `Builder` facilitates upgrading of a [`Transport`] for use with @@ -105,9 +93,11 @@ where /// /// * I/O upgrade: `C -> (PeerId, D)`. /// * Transport output: `C -> (PeerId, D)` - pub fn authenticate(self, upgrade: U) -> Authenticated< - AndThen Authenticate + Clone> - > where + pub fn authenticate( + self, + upgrade: U, + ) -> Authenticated Authenticate + Clone>> + where T: Transport, C: AsyncRead + AsyncWrite + Unpin, D: AsyncRead + AsyncWrite + Unpin, @@ -116,11 +106,12 @@ where E: Error + 'static, { let version = self.version; - Authenticated(Builder::new(self.inner.and_then(move |conn, endpoint| { - Authenticate { - inner: upgrade::apply(conn, upgrade, endpoint, version) - } - }), version)) + Authenticated(Builder::new( + self.inner.and_then(move |conn, endpoint| Authenticate { + inner: upgrade::apply(conn, upgrade, endpoint, version), + }), + version, + )) } } @@ -132,19 +123,21 @@ where pub struct Authenticate where C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade> + U: InboundUpgrade> + OutboundUpgrade>, { #[pin] - inner: EitherUpgrade + inner: EitherUpgrade, } impl Future for Authenticate where C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade, - Output = >>::Output, - Error = >>::Error - > + U: InboundUpgrade> + + OutboundUpgrade< + Negotiated, + Output = >>::Output, + Error = >>::Error, + >, { type Output = as Future>::Output; @@ -173,7 +166,7 @@ impl Future for Multiplex where C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = M, Error = E>, - U: OutboundUpgrade, Output = M, Error = E> + U: OutboundUpgrade, Output = M, Error = E>, { type Output = Result<(PeerId, M), UpgradeError>; @@ -183,7 +176,10 @@ where Ok(m) => m, Err(err) => return Poll::Ready(Err(err)), }; - let i = this.peer_id.take().expect("Multiplex future polled after completion."); + let i = this + .peer_id + .take() + .expect("Multiplex future polled after completion."); Poll::Ready(Ok((i, m))) } } @@ -195,7 +191,7 @@ pub struct Authenticated(Builder); impl Authenticated where T: Transport, - T::Error: 'static + T::Error: 'static, { /// Applies an arbitrary upgrade. /// @@ -216,7 +212,10 @@ where U: OutboundUpgrade, Output = D, Error = E> + Clone, E: Error + 'static, { - Authenticated(Builder::new(Upgrade::new(self.0.inner, upgrade), self.0.version)) + Authenticated(Builder::new( + Upgrade::new(self.0.inner, upgrade), + self.0.version, + )) } /// Upgrades the transport with a (sub)stream multiplexer. @@ -229,9 +228,11 @@ where /// /// * I/O upgrade: `C -> M`. /// * Transport output: `(PeerId, C) -> (PeerId, M)`. - pub fn multiplex(self, upgrade: U) -> Multiplexed< - AndThen Multiplex + Clone> - > where + pub fn multiplex( + self, + upgrade: U, + ) -> Multiplexed Multiplex + Clone>> + where T: Transport, C: AsyncRead + AsyncWrite + Unpin, M: StreamMuxer, @@ -242,7 +243,10 @@ where let version = self.0.version; Multiplexed(self.0.inner.and_then(move |(i, c), endpoint| { let upgrade = upgrade::apply(c, upgrade, endpoint, version); - Multiplex { peer_id: Some(i), upgrade } + Multiplex { + peer_id: Some(i), + upgrade, + } })) } @@ -257,21 +261,26 @@ where /// /// * I/O upgrade: `C -> M`. /// * Transport output: `(PeerId, C) -> (PeerId, M)`. - pub fn multiplex_ext(self, up: F) -> Multiplexed< - AndThen Multiplex + Clone> - > where + pub fn multiplex_ext( + self, + up: F, + ) -> Multiplexed Multiplex + Clone>> + where T: Transport, C: AsyncRead + AsyncWrite + Unpin, M: StreamMuxer, U: InboundUpgrade, Output = M, Error = E>, U: OutboundUpgrade, Output = M, Error = E> + Clone, E: Error + 'static, - F: for<'a> FnOnce(&'a PeerId, &'a ConnectedPoint) -> U + Clone + F: for<'a> FnOnce(&'a PeerId, &'a ConnectedPoint) -> U + Clone, { let version = self.0.version; Multiplexed(self.0.inner.and_then(move |(peer_id, c), endpoint| { let upgrade = upgrade::apply(c, up(&peer_id, &endpoint), endpoint, version); - Multiplex { peer_id: Some(peer_id), upgrade } + Multiplex { + peer_id: Some(peer_id), + upgrade, + } })) } } @@ -293,7 +302,7 @@ impl Multiplexed { T::Error: Send + Sync, M: StreamMuxer + Send + Sync + 'static, M::Substream: Send + 'static, - M::OutboundSubstream: Send + 'static + M::OutboundSubstream: Send + 'static, { boxed(self.map(|(i, m), _| (i, StreamMuxerBox::new(m)))) } @@ -347,7 +356,10 @@ type EitherUpgrade = future::Either, OutboundUpg /// /// See [`Transport::upgrade`] #[derive(Debug, Copy, Clone)] -pub struct Upgrade { inner: T, upgrade: U } +pub struct Upgrade { + inner: T, + upgrade: U, +} impl Upgrade { pub fn new(inner: T, upgrade: U) -> Self { @@ -362,7 +374,7 @@ where C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = D, Error = E>, U: OutboundUpgrade, Output = D, Error = E> + Clone, - E: Error + 'static + E: Error + 'static, { type Output = (PeerId, D); type Error = TransportUpgradeError; @@ -371,20 +383,24 @@ where type Dial = DialUpgradeFuture; fn dial(self, addr: Multiaddr) -> Result> { - let future = self.inner.dial(addr) + let future = self + .inner + .dial(addr) .map_err(|err| err.map(TransportUpgradeError::Transport))?; Ok(DialUpgradeFuture { future: Box::pin(future), - upgrade: future::Either::Left(Some(self.upgrade)) + upgrade: future::Either::Left(Some(self.upgrade)), }) } fn listen_on(self, addr: Multiaddr) -> Result> { - let stream = self.inner.listen_on(addr) + let stream = self + .inner + .listen_on(addr) .map_err(|err| err.map(TransportUpgradeError::Transport))?; Ok(ListenerStream { stream: Box::pin(stream), - upgrade: self.upgrade + upgrade: self.upgrade, }) } @@ -435,7 +451,7 @@ where C: AsyncRead + AsyncWrite + Unpin, { future: Pin>, - upgrade: future::Either, (Option, OutboundUpgradeApply)> + upgrade: future::Either, (Option, OutboundUpgradeApply)>, } impl Future for DialUpgradeFuture @@ -443,7 +459,7 @@ where F: TryFuture, C: AsyncRead + AsyncWrite + Unpin, U: OutboundUpgrade, Output = D>, - U::Error: Error + U::Error: Error, { type Output = Result<(PeerId, D), TransportUpgradeError>; @@ -455,20 +471,28 @@ where loop { this.upgrade = match this.upgrade { future::Either::Left(ref mut up) => { - let (i, c) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx).map_err(TransportUpgradeError::Transport)) { + let (i, c) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx) + .map_err(TransportUpgradeError::Transport)) + { Ok(v) => v, Err(err) => return Poll::Ready(Err(err)), }; - let u = up.take().expect("DialUpgradeFuture is constructed with Either::Left(Some)."); + let u = up + .take() + .expect("DialUpgradeFuture is constructed with Either::Left(Some)."); future::Either::Right((Some(i), apply_outbound(c, u, upgrade::Version::V1))) } future::Either::Right((ref mut i, ref mut up)) => { - let d = match ready!(Future::poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade)) { + let d = match ready!( + Future::poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade) + ) { Ok(d) => d, Err(err) => return Poll::Ready(Err(err)), }; - let i = i.take().expect("DialUpgradeFuture polled after completion."); - return Poll::Ready(Ok((i, d))) + let i = i + .take() + .expect("DialUpgradeFuture polled after completion."); + return Poll::Ready(Ok((i, d))); } } } @@ -485,7 +509,7 @@ where /// The [`Transport::Listener`] stream of an [`Upgrade`]d transport. pub struct ListenerStream { stream: Pin>, - upgrade: U + upgrade: U, } impl Stream for ListenerStream @@ -493,42 +517,40 @@ where S: TryStream, Error = E>, F: TryFuture, C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade, Output = D> + Clone + U: InboundUpgrade, Output = D> + Clone, { - type Item = Result, TransportUpgradeError>, TransportUpgradeError>; + type Item = Result< + ListenerEvent, TransportUpgradeError>, + TransportUpgradeError, + >; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match ready!(TryStream::try_poll_next(self.stream.as_mut(), cx)) { Some(Ok(event)) => { let event = event - .map(move |future| { - ListenerUpgradeFuture { - future: Box::pin(future), - upgrade: future::Either::Left(Some(self.upgrade.clone())) - } + .map(move |future| ListenerUpgradeFuture { + future: Box::pin(future), + upgrade: future::Either::Left(Some(self.upgrade.clone())), }) .map_err(TransportUpgradeError::Transport); Poll::Ready(Some(Ok(event))) } - Some(Err(err)) => { - Poll::Ready(Some(Err(TransportUpgradeError::Transport(err)))) - } - None => Poll::Ready(None) + Some(Err(err)) => Poll::Ready(Some(Err(TransportUpgradeError::Transport(err)))), + None => Poll::Ready(None), } } } -impl Unpin for ListenerStream { -} +impl Unpin for ListenerStream {} /// The [`Transport::ListenerUpgrade`] future of an [`Upgrade`]d transport. pub struct ListenerUpgradeFuture where C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + U: InboundUpgrade>, { future: Pin>, - upgrade: future::Either, (Option, InboundUpgradeApply)> + upgrade: future::Either, (Option, InboundUpgradeApply)>, } impl Future for ListenerUpgradeFuture @@ -536,7 +558,7 @@ where F: TryFuture, C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = D>, - U::Error: Error + U::Error: Error, { type Output = Result<(PeerId, D), TransportUpgradeError>; @@ -548,20 +570,28 @@ where loop { this.upgrade = match this.upgrade { future::Either::Left(ref mut up) => { - let (i, c) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx).map_err(TransportUpgradeError::Transport)) { + let (i, c) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx) + .map_err(TransportUpgradeError::Transport)) + { Ok(v) => v, - Err(err) => return Poll::Ready(Err(err)) + Err(err) => return Poll::Ready(Err(err)), }; - let u = up.take().expect("ListenerUpgradeFuture is constructed with Either::Left(Some)."); + let u = up + .take() + .expect("ListenerUpgradeFuture is constructed with Either::Left(Some)."); future::Either::Right((Some(i), apply_inbound(c, u))) } future::Either::Right((ref mut i, ref mut up)) => { - let d = match ready!(TryFuture::try_poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade)) { + let d = match ready!(TryFuture::try_poll(Pin::new(up), cx) + .map_err(TransportUpgradeError::Upgrade)) + { Ok(v) => v, - Err(err) => return Poll::Ready(Err(err)) + Err(err) => return Poll::Ready(Err(err)), }; - let i = i.take().expect("ListenerUpgradeFuture polled after completion."); - return Poll::Ready(Ok((i, d))) + let i = i + .take() + .expect("ListenerUpgradeFuture polled after completion."); + return Poll::Ready(Ok((i, d))); } } } @@ -571,6 +601,6 @@ where impl Unpin for ListenerUpgradeFuture where C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + U: InboundUpgrade>, { } diff --git a/core/src/upgrade.rs b/core/src/upgrade.rs index 376cbfc1d2b..fbee321a83d 100644 --- a/core/src/upgrade.rs +++ b/core/src/upgrade.rs @@ -69,21 +69,21 @@ mod transfer; use futures::future::Future; -pub use crate::Negotiated; -pub use multistream_select::{Version, NegotiatedComplete, NegotiationError, ProtocolError}; +#[allow(deprecated)] +pub use self::transfer::ReadOneError; pub use self::{ apply::{apply, apply_inbound, apply_outbound, InboundUpgradeApply, OutboundUpgradeApply}, denied::DeniedUpgrade, either::EitherUpgrade, error::UpgradeError, from_fn::{from_fn, FromFnUpgrade}, - map::{MapInboundUpgrade, MapOutboundUpgrade, MapInboundUpgradeErr, MapOutboundUpgradeErr}, + map::{MapInboundUpgrade, MapInboundUpgradeErr, MapOutboundUpgrade, MapOutboundUpgradeErr}, optional::OptionalUpgrade, select::SelectUpgrade, - transfer::{write_length_prefixed, write_varint, read_length_prefixed, read_varint}, + transfer::{read_length_prefixed, read_varint, write_length_prefixed, write_varint}, }; -#[allow(deprecated)] -pub use self::transfer::ReadOneError; +pub use crate::Negotiated; +pub use multistream_select::{NegotiatedComplete, NegotiationError, ProtocolError, Version}; /// Types serving as protocol names. /// @@ -167,7 +167,7 @@ pub trait InboundUpgradeExt: InboundUpgrade { fn map_inbound(self, f: F) -> MapInboundUpgrade where Self: Sized, - F: FnOnce(Self::Output) -> T + F: FnOnce(Self::Output) -> T, { MapInboundUpgrade::new(self, f) } @@ -176,7 +176,7 @@ pub trait InboundUpgradeExt: InboundUpgrade { fn map_inbound_err(self, f: F) -> MapInboundUpgradeErr where Self: Sized, - F: FnOnce(Self::Error) -> T + F: FnOnce(Self::Error) -> T, { MapInboundUpgradeErr::new(self, f) } @@ -207,7 +207,7 @@ pub trait OutboundUpgradeExt: OutboundUpgrade { fn map_outbound(self, f: F) -> MapOutboundUpgrade where Self: Sized, - F: FnOnce(Self::Output) -> T + F: FnOnce(Self::Output) -> T, { MapOutboundUpgrade::new(self, f) } @@ -216,11 +216,10 @@ pub trait OutboundUpgradeExt: OutboundUpgrade { fn map_outbound_err(self, f: F) -> MapOutboundUpgradeErr where Self: Sized, - F: FnOnce(Self::Error) -> T + F: FnOnce(Self::Error) -> T, { MapOutboundUpgradeErr::new(self, f) } } impl> OutboundUpgradeExt for U {} - diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index eaf25e884b3..3b4763d2303 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -18,8 +18,8 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use crate::upgrade::{InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeError}; use crate::{ConnectedPoint, Negotiated}; -use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName}; use futures::{future::Either, prelude::*}; use log::debug; use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; @@ -28,8 +28,12 @@ use std::{iter, mem, pin::Pin, task::Context, task::Poll}; pub use multistream_select::Version; /// Applies an upgrade to the inbound and outbound direction of a connection or substream. -pub fn apply(conn: C, up: U, cp: ConnectedPoint, v: Version) - -> Either, OutboundUpgradeApply> +pub fn apply( + conn: C, + up: U, + cp: ConnectedPoint, + v: Version, +) -> Either, OutboundUpgradeApply> where C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade> + OutboundUpgrade>, @@ -47,10 +51,16 @@ where C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade>, { - let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); + let iter = up + .protocol_info() + .into_iter() + .map(NameWrap as fn(_) -> NameWrap<_>); let future = multistream_select::listener_select_proto(conn, iter); InboundUpgradeApply { - inner: InboundUpgradeApplyState::Init { future, upgrade: up } + inner: InboundUpgradeApplyState::Init { + future, + upgrade: up, + }, } } @@ -58,12 +68,18 @@ where pub fn apply_outbound(conn: C, up: U, v: Version) -> OutboundUpgradeApply where C: AsyncRead + AsyncWrite + Unpin, - U: OutboundUpgrade> + U: OutboundUpgrade>, { - let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); + let iter = up + .protocol_info() + .into_iter() + .map(NameWrap as fn(_) -> NameWrap<_>); let future = multistream_select::dialer_select_proto(conn, iter, v); OutboundUpgradeApply { - inner: OutboundUpgradeApplyState::Init { future, upgrade: up } + inner: OutboundUpgradeApplyState::Init { + future, + upgrade: up, + }, } } @@ -71,9 +87,9 @@ where pub struct InboundUpgradeApply where C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + U: InboundUpgrade>, { - inner: InboundUpgradeApplyState + inner: InboundUpgradeApplyState, } enum InboundUpgradeApplyState @@ -86,9 +102,9 @@ where upgrade: U, }, Upgrade { - future: Pin> + future: Pin>, }, - Undefined + Undefined, } impl Unpin for InboundUpgradeApply @@ -108,36 +124,40 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) { - InboundUpgradeApplyState::Init { mut future, upgrade } => { + InboundUpgradeApplyState::Init { + mut future, + upgrade, + } => { let (info, io) = match Future::poll(Pin::new(&mut future), cx)? { Poll::Ready(x) => x, Poll::Pending => { self.inner = InboundUpgradeApplyState::Init { future, upgrade }; - return Poll::Pending + return Poll::Pending; } }; self.inner = InboundUpgradeApplyState::Upgrade { - future: Box::pin(upgrade.upgrade_inbound(io, info.0)) + future: Box::pin(upgrade.upgrade_inbound(io, info.0)), }; } InboundUpgradeApplyState::Upgrade { mut future } => { match Future::poll(Pin::new(&mut future), cx) { Poll::Pending => { self.inner = InboundUpgradeApplyState::Upgrade { future }; - return Poll::Pending + return Poll::Pending; } Poll::Ready(Ok(x)) => { debug!("Successfully applied negotiated protocol"); - return Poll::Ready(Ok(x)) + return Poll::Ready(Ok(x)); } Poll::Ready(Err(e)) => { debug!("Failed to apply negotiated protocol"); - return Poll::Ready(Err(UpgradeError::Apply(e))) + return Poll::Ready(Err(UpgradeError::Apply(e))); } } } - InboundUpgradeApplyState::Undefined => + InboundUpgradeApplyState::Undefined => { panic!("InboundUpgradeApplyState::poll called after completion") + } } } } @@ -147,24 +167,24 @@ where pub struct OutboundUpgradeApply where C: AsyncRead + AsyncWrite + Unpin, - U: OutboundUpgrade> + U: OutboundUpgrade>, { - inner: OutboundUpgradeApplyState + inner: OutboundUpgradeApplyState, } enum OutboundUpgradeApplyState where C: AsyncRead + AsyncWrite + Unpin, - U: OutboundUpgrade> + U: OutboundUpgrade>, { Init { future: DialerSelectFuture::IntoIter>>, - upgrade: U + upgrade: U, }, Upgrade { - future: Pin> + future: Pin>, }, - Undefined + Undefined, } impl Unpin for OutboundUpgradeApply @@ -184,27 +204,30 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) { - OutboundUpgradeApplyState::Init { mut future, upgrade } => { + OutboundUpgradeApplyState::Init { + mut future, + upgrade, + } => { let (info, connection) = match Future::poll(Pin::new(&mut future), cx)? { Poll::Ready(x) => x, Poll::Pending => { self.inner = OutboundUpgradeApplyState::Init { future, upgrade }; - return Poll::Pending + return Poll::Pending; } }; self.inner = OutboundUpgradeApplyState::Upgrade { - future: Box::pin(upgrade.upgrade_outbound(connection, info.0)) + future: Box::pin(upgrade.upgrade_outbound(connection, info.0)), }; } OutboundUpgradeApplyState::Upgrade { mut future } => { match Future::poll(Pin::new(&mut future), cx) { Poll::Pending => { self.inner = OutboundUpgradeApplyState::Upgrade { future }; - return Poll::Pending + return Poll::Pending; } Poll::Ready(Ok(x)) => { debug!("Successfully applied negotiated protocol"); - return Poll::Ready(Ok(x)) + return Poll::Ready(Ok(x)); } Poll::Ready(Err(e)) => { debug!("Failed to apply negotiated protocol"); @@ -212,8 +235,9 @@ where } } } - OutboundUpgradeApplyState::Undefined => + OutboundUpgradeApplyState::Undefined => { panic!("OutboundUpgradeApplyState::poll called after completion") + } } } } @@ -230,4 +254,3 @@ impl AsRef<[u8]> for NameWrap { self.0.protocol_name() } } - diff --git a/core/src/upgrade/either.rs b/core/src/upgrade/either.rs index 28db987ccd7..8b5c7f71422 100644 --- a/core/src/upgrade/either.rs +++ b/core/src/upgrade/either.rs @@ -19,29 +19,32 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - either::{EitherOutput, EitherError, EitherFuture2, EitherName}, - upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo} + either::{EitherError, EitherFuture2, EitherName, EitherOutput}, + upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}, }; /// A type to represent two possible upgrade types (inbound or outbound). #[derive(Debug, Clone)] -pub enum EitherUpgrade { A(A), B(B) } +pub enum EitherUpgrade { + A(A), + B(B), +} impl UpgradeInfo for EitherUpgrade where A: UpgradeInfo, - B: UpgradeInfo + B: UpgradeInfo, { type Info = EitherName; type InfoIter = EitherIter< ::IntoIter, - ::IntoIter + ::IntoIter, >; fn protocol_info(&self) -> Self::InfoIter { match self { EitherUpgrade::A(a) => EitherIter::A(a.protocol_info().into_iter()), - EitherUpgrade::B(b) => EitherIter::B(b.protocol_info().into_iter()) + EitherUpgrade::B(b) => EitherIter::B(b.protocol_info().into_iter()), } } } @@ -57,9 +60,13 @@ where fn upgrade_inbound(self, sock: C, info: Self::Info) -> Self::Future { match (self, info) { - (EitherUpgrade::A(a), EitherName::A(info)) => EitherFuture2::A(a.upgrade_inbound(sock, info)), - (EitherUpgrade::B(b), EitherName::B(info)) => EitherFuture2::B(b.upgrade_inbound(sock, info)), - _ => panic!("Invalid invocation of EitherUpgrade::upgrade_inbound") + (EitherUpgrade::A(a), EitherName::A(info)) => { + EitherFuture2::A(a.upgrade_inbound(sock, info)) + } + (EitherUpgrade::B(b), EitherName::B(info)) => { + EitherFuture2::B(b.upgrade_inbound(sock, info)) + } + _ => panic!("Invalid invocation of EitherUpgrade::upgrade_inbound"), } } } @@ -75,36 +82,42 @@ where fn upgrade_outbound(self, sock: C, info: Self::Info) -> Self::Future { match (self, info) { - (EitherUpgrade::A(a), EitherName::A(info)) => EitherFuture2::A(a.upgrade_outbound(sock, info)), - (EitherUpgrade::B(b), EitherName::B(info)) => EitherFuture2::B(b.upgrade_outbound(sock, info)), - _ => panic!("Invalid invocation of EitherUpgrade::upgrade_outbound") + (EitherUpgrade::A(a), EitherName::A(info)) => { + EitherFuture2::A(a.upgrade_outbound(sock, info)) + } + (EitherUpgrade::B(b), EitherName::B(info)) => { + EitherFuture2::B(b.upgrade_outbound(sock, info)) + } + _ => panic!("Invalid invocation of EitherUpgrade::upgrade_outbound"), } } } /// A type to represent two possible `Iterator` types. #[derive(Debug, Clone)] -pub enum EitherIter { A(A), B(B) } +pub enum EitherIter { + A(A), + B(B), +} impl Iterator for EitherIter where A: Iterator, - B: Iterator + B: Iterator, { type Item = EitherName; fn next(&mut self) -> Option { match self { EitherIter::A(a) => a.next().map(EitherName::A), - EitherIter::B(b) => b.next().map(EitherName::B) + EitherIter::B(b) => b.next().map(EitherName::B), } } fn size_hint(&self) -> (usize, Option) { match self { EitherIter::A(a) => a.size_hint(), - EitherIter::B(b) => b.size_hint() + EitherIter::B(b) => b.size_hint(), } } } - diff --git a/core/src/upgrade/error.rs b/core/src/upgrade/error.rs index de0ecadbd51..2bbe95ecf2a 100644 --- a/core/src/upgrade/error.rs +++ b/core/src/upgrade/error.rs @@ -33,7 +33,7 @@ pub enum UpgradeError { impl UpgradeError { pub fn map_err(self, f: F) -> UpgradeError where - F: FnOnce(E) -> T + F: FnOnce(E) -> T, { match self { UpgradeError::Select(e) => UpgradeError::Select(e), @@ -43,7 +43,7 @@ impl UpgradeError { pub fn into_err(self) -> UpgradeError where - T: From + T: From, { self.map_err(Into::into) } @@ -51,7 +51,7 @@ impl UpgradeError { impl fmt::Display for UpgradeError where - E: fmt::Display + E: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -63,7 +63,7 @@ where impl std::error::Error for UpgradeError where - E: std::error::Error + 'static + E: std::error::Error + 'static, { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { @@ -78,4 +78,3 @@ impl From for UpgradeError { UpgradeError::Select(e) } } - diff --git a/core/src/upgrade/from_fn.rs b/core/src/upgrade/from_fn.rs index 0c8947e5b30..97bbc2eb292 100644 --- a/core/src/upgrade/from_fn.rs +++ b/core/src/upgrade/from_fn.rs @@ -18,7 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{Endpoint, upgrade::{InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeInfo}}; +use crate::{ + upgrade::{InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeInfo}, + Endpoint, +}; use futures::prelude::*; use std::iter; diff --git a/core/src/upgrade/map.rs b/core/src/upgrade/map.rs index 2f5ca31e207..c5fe34f44b5 100644 --- a/core/src/upgrade/map.rs +++ b/core/src/upgrade/map.rs @@ -24,7 +24,10 @@ use std::{pin::Pin, task::Context, task::Poll}; /// Wraps around an upgrade and applies a closure to the output. #[derive(Debug, Clone)] -pub struct MapInboundUpgrade { upgrade: U, fun: F } +pub struct MapInboundUpgrade { + upgrade: U, + fun: F, +} impl MapInboundUpgrade { pub fn new(upgrade: U, fun: F) -> Self { @@ -34,7 +37,7 @@ impl MapInboundUpgrade { impl UpgradeInfo for MapInboundUpgrade where - U: UpgradeInfo + U: UpgradeInfo, { type Info = U::Info; type InfoIter = U::InfoIter; @@ -47,7 +50,7 @@ where impl InboundUpgrade for MapInboundUpgrade where U: InboundUpgrade, - F: FnOnce(U::Output) -> T + F: FnOnce(U::Output) -> T, { type Output = T; type Error = U::Error; @@ -56,7 +59,7 @@ where fn upgrade_inbound(self, sock: C, info: Self::Info) -> Self::Future { MapFuture { inner: self.upgrade.upgrade_inbound(sock, info), - map: Some(self.fun) + map: Some(self.fun), } } } @@ -76,7 +79,10 @@ where /// Wraps around an upgrade and applies a closure to the output. #[derive(Debug, Clone)] -pub struct MapOutboundUpgrade { upgrade: U, fun: F } +pub struct MapOutboundUpgrade { + upgrade: U, + fun: F, +} impl MapOutboundUpgrade { pub fn new(upgrade: U, fun: F) -> Self { @@ -86,7 +92,7 @@ impl MapOutboundUpgrade { impl UpgradeInfo for MapOutboundUpgrade where - U: UpgradeInfo + U: UpgradeInfo, { type Info = U::Info; type InfoIter = U::InfoIter; @@ -112,7 +118,7 @@ where impl OutboundUpgrade for MapOutboundUpgrade where U: OutboundUpgrade, - F: FnOnce(U::Output) -> T + F: FnOnce(U::Output) -> T, { type Output = T; type Error = U::Error; @@ -121,14 +127,17 @@ where fn upgrade_outbound(self, sock: C, info: Self::Info) -> Self::Future { MapFuture { inner: self.upgrade.upgrade_outbound(sock, info), - map: Some(self.fun) + map: Some(self.fun), } } } /// Wraps around an upgrade and applies a closure to the error. #[derive(Debug, Clone)] -pub struct MapInboundUpgradeErr { upgrade: U, fun: F } +pub struct MapInboundUpgradeErr { + upgrade: U, + fun: F, +} impl MapInboundUpgradeErr { pub fn new(upgrade: U, fun: F) -> Self { @@ -138,7 +147,7 @@ impl MapInboundUpgradeErr { impl UpgradeInfo for MapInboundUpgradeErr where - U: UpgradeInfo + U: UpgradeInfo, { type Info = U::Info; type InfoIter = U::InfoIter; @@ -151,7 +160,7 @@ where impl InboundUpgrade for MapInboundUpgradeErr where U: InboundUpgrade, - F: FnOnce(U::Error) -> T + F: FnOnce(U::Error) -> T, { type Output = U::Output; type Error = T; @@ -160,7 +169,7 @@ where fn upgrade_inbound(self, sock: C, info: Self::Info) -> Self::Future { MapErrFuture { fut: self.upgrade.upgrade_inbound(sock, info), - fun: Some(self.fun) + fun: Some(self.fun), } } } @@ -180,7 +189,10 @@ where /// Wraps around an upgrade and applies a closure to the error. #[derive(Debug, Clone)] -pub struct MapOutboundUpgradeErr { upgrade: U, fun: F } +pub struct MapOutboundUpgradeErr { + upgrade: U, + fun: F, +} impl MapOutboundUpgradeErr { pub fn new(upgrade: U, fun: F) -> Self { @@ -190,7 +202,7 @@ impl MapOutboundUpgradeErr { impl UpgradeInfo for MapOutboundUpgradeErr where - U: UpgradeInfo + U: UpgradeInfo, { type Info = U::Info; type InfoIter = U::InfoIter; @@ -203,7 +215,7 @@ where impl OutboundUpgrade for MapOutboundUpgradeErr where U: OutboundUpgrade, - F: FnOnce(U::Error) -> T + F: FnOnce(U::Error) -> T, { type Output = U::Output; type Error = T; @@ -212,14 +224,14 @@ where fn upgrade_outbound(self, sock: C, info: Self::Info) -> Self::Future { MapErrFuture { fut: self.upgrade.upgrade_outbound(sock, info), - fun: Some(self.fun) + fun: Some(self.fun), } } } impl InboundUpgrade for MapOutboundUpgradeErr where - U: InboundUpgrade + U: InboundUpgrade, { type Output = U::Output; type Error = U::Error; @@ -283,4 +295,3 @@ where } } } - diff --git a/core/src/upgrade/optional.rs b/core/src/upgrade/optional.rs index 02dc3c48f78..c661a4f0170 100644 --- a/core/src/upgrade/optional.rs +++ b/core/src/upgrade/optional.rs @@ -112,8 +112,4 @@ where } } -impl ExactSizeIterator for Iter -where - T: ExactSizeIterator -{ -} +impl ExactSizeIterator for Iter where T: ExactSizeIterator {} diff --git a/core/src/upgrade/select.rs b/core/src/upgrade/select.rs index 8fa4c5b8a7a..d1a8cabca2f 100644 --- a/core/src/upgrade/select.rs +++ b/core/src/upgrade/select.rs @@ -19,8 +19,8 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - either::{EitherOutput, EitherError, EitherFuture2, EitherName}, - upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo} + either::{EitherError, EitherFuture2, EitherName, EitherOutput}, + upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}, }; /// Upgrade that combines two upgrades into one. Supports all the protocols supported by either @@ -42,16 +42,19 @@ impl SelectUpgrade { impl UpgradeInfo for SelectUpgrade where A: UpgradeInfo, - B: UpgradeInfo + B: UpgradeInfo, { type Info = EitherName; type InfoIter = InfoIterChain< ::IntoIter, - ::IntoIter + ::IntoIter, >; fn protocol_info(&self) -> Self::InfoIter { - InfoIterChain(self.0.protocol_info().into_iter(), self.1.protocol_info().into_iter()) + InfoIterChain( + self.0.protocol_info().into_iter(), + self.1.protocol_info().into_iter(), + ) } } @@ -67,7 +70,7 @@ where fn upgrade_inbound(self, sock: C, info: Self::Info) -> Self::Future { match info { EitherName::A(info) => EitherFuture2::A(self.0.upgrade_inbound(sock, info)), - EitherName::B(info) => EitherFuture2::B(self.1.upgrade_inbound(sock, info)) + EitherName::B(info) => EitherFuture2::B(self.1.upgrade_inbound(sock, info)), } } } @@ -84,7 +87,7 @@ where fn upgrade_outbound(self, sock: C, info: Self::Info) -> Self::Future { match info { EitherName::A(info) => EitherFuture2::A(self.0.upgrade_outbound(sock, info)), - EitherName::B(info) => EitherFuture2::B(self.1.upgrade_outbound(sock, info)) + EitherName::B(info) => EitherFuture2::B(self.1.upgrade_outbound(sock, info)), } } } @@ -96,16 +99,16 @@ pub struct InfoIterChain(A, B); impl Iterator for InfoIterChain where A: Iterator, - B: Iterator + B: Iterator, { type Item = EitherName; fn next(&mut self) -> Option { if let Some(info) = self.0.next() { - return Some(EitherName::A(info)) + return Some(EitherName::A(info)); } if let Some(info) = self.1.next() { - return Some(EitherName::B(info)) + return Some(EitherName::B(info)); } None } @@ -117,4 +120,3 @@ where (min1.saturating_add(min2), max) } } - diff --git a/core/src/upgrade/transfer.rs b/core/src/upgrade/transfer.rs index 500ece523c5..fd8127758f1 100644 --- a/core/src/upgrade/transfer.rs +++ b/core/src/upgrade/transfer.rs @@ -29,9 +29,10 @@ use std::{error, fmt, io}; /// /// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is /// > compatible with what [`read_length_prefixed`] expects. -pub async fn write_length_prefixed(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>) - -> Result<(), io::Error> -{ +pub async fn write_length_prefixed( + socket: &mut (impl AsyncWrite + Unpin), + data: impl AsRef<[u8]>, +) -> Result<(), io::Error> { write_varint(socket, data.as_ref().len()).await?; socket.write_all(data.as_ref()).await?; socket.flush().await?; @@ -44,11 +45,15 @@ pub async fn write_length_prefixed(socket: &mut (impl AsyncWrite + Unpin), data: /// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is /// > compatible with what `read_one` expects. /// -#[deprecated(since = "0.29.0", note = "Use `write_length_prefixed` instead. You will need to manually close the stream using `socket.close().await`.")] +#[deprecated( + since = "0.29.0", + note = "Use `write_length_prefixed` instead. You will need to manually close the stream using `socket.close().await`." +)] #[allow(dead_code)] -pub async fn write_one(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>) - -> Result<(), io::Error> -{ +pub async fn write_one( + socket: &mut (impl AsyncWrite + Unpin), + data: impl AsRef<[u8]>, +) -> Result<(), io::Error> { write_varint(socket, data.as_ref().len()).await?; socket.write_all(data.as_ref()).await?; socket.close().await?; @@ -61,9 +66,10 @@ pub async fn write_one(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef< /// > compatible with what `read_one` expects. #[deprecated(since = "0.29.0", note = "Use `write_length_prefixed` instead.")] #[allow(dead_code)] -pub async fn write_with_len_prefix(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>) - -> Result<(), io::Error> -{ +pub async fn write_with_len_prefix( + socket: &mut (impl AsyncWrite + Unpin), + data: impl AsRef<[u8]>, +) -> Result<(), io::Error> { write_varint(socket, data.as_ref().len()).await?; socket.write_all(data.as_ref()).await?; socket.flush().await?; @@ -73,9 +79,10 @@ pub async fn write_with_len_prefix(socket: &mut (impl AsyncWrite + Unpin), data: /// Writes a variable-length integer to the `socket`. /// /// > **Note**: Does **NOT** flush the socket. -pub async fn write_varint(socket: &mut (impl AsyncWrite + Unpin), len: usize) - -> Result<(), io::Error> -{ +pub async fn write_varint( + socket: &mut (impl AsyncWrite + Unpin), + len: usize, +) -> Result<(), io::Error> { let mut len_data = unsigned_varint::encode::usize_buffer(); let encoded_len = unsigned_varint::encode::usize(len, &mut len_data).len(); socket.write_all(&len_data[..encoded_len]).await?; @@ -95,7 +102,7 @@ pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result { // Reaching EOF before finishing to read the length is an error, unless the EOF is // at the very beginning of the substream, in which case we assume that the data is @@ -116,7 +123,7 @@ pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result { return Err(io::Error::new( io::ErrorKind::InvalidData, - "overflow in variable-length integer" + "overflow in variable-length integer", )); } // TODO: why do we have a `__Nonexhaustive` variant in the error? I don't know how to process it @@ -134,11 +141,19 @@ pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result **Note**: Assumes that a variable-length prefix indicates the length of the message. This is /// > compatible with what [`write_length_prefixed`] does. -pub async fn read_length_prefixed(socket: &mut (impl AsyncRead + Unpin), max_size: usize) -> io::Result> -{ +pub async fn read_length_prefixed( + socket: &mut (impl AsyncRead + Unpin), + max_size: usize, +) -> io::Result> { let len = read_varint(socket).await?; if len > max_size { - return Err(io::Error::new(io::ErrorKind::InvalidData, format!("Received data size ({} bytes) exceeds maximum ({} bytes)", len, max_size))) + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "Received data size ({} bytes) exceeds maximum ({} bytes)", + len, max_size + ), + )); } let mut buf = vec![0; len]; @@ -157,9 +172,10 @@ pub async fn read_length_prefixed(socket: &mut (impl AsyncRead + Unpin), max_siz /// > compatible with what `write_one` does. #[deprecated(since = "0.29.0", note = "Use `read_length_prefixed` instead.")] #[allow(dead_code, deprecated)] -pub async fn read_one(socket: &mut (impl AsyncRead + Unpin), max_size: usize) - -> Result, ReadOneError> -{ +pub async fn read_one( + socket: &mut (impl AsyncRead + Unpin), + max_size: usize, +) -> Result, ReadOneError> { let len = read_varint(socket).await?; if len > max_size { return Err(ReadOneError::TooLarge { @@ -175,7 +191,10 @@ pub async fn read_one(socket: &mut (impl AsyncRead + Unpin), max_size: usize) /// Error while reading one message. #[derive(Debug)] -#[deprecated(since = "0.29.0", note = "Use `read_length_prefixed` instead of `read_one` to avoid depending on this type.")] +#[deprecated( + since = "0.29.0", + note = "Use `read_length_prefixed` instead of `read_one` to avoid depending on this type." +)] pub enum ReadOneError { /// Error on the socket. Io(std::io::Error), @@ -239,7 +258,7 @@ mod tests { } // TODO: rewrite these tests -/* + /* #[test] fn read_one_works() { let original_data = (0..rand::random::() % 10_000) diff --git a/core/tests/connection_limits.rs b/core/tests/connection_limits.rs index 178eacbd192..65e61c4b3c4 100644 --- a/core/tests/connection_limits.rs +++ b/core/tests/connection_limits.rs @@ -20,16 +20,16 @@ mod util; -use futures::{ready, future::poll_fn}; +use futures::{future::poll_fn, ready}; use libp2p_core::multiaddr::{multiaddr, Multiaddr}; use libp2p_core::{ - PeerId, connection::PendingConnectionError, - network::{NetworkEvent, NetworkConfig, ConnectionLimits, DialError}, + network::{ConnectionLimits, DialError, NetworkConfig, NetworkEvent}, + PeerId, }; use rand::Rng; use std::task::Poll; -use util::{TestHandler, test_network}; +use util::{test_network, TestHandler}; #[test] fn max_outgoing() { @@ -40,14 +40,16 @@ fn max_outgoing() { let mut network = test_network(cfg); let target = PeerId::random(); - for _ in 0 .. outgoing_limit { - network.peer(target.clone()) + for _ in 0..outgoing_limit { + network + .peer(target.clone()) .dial(Multiaddr::empty(), Vec::new(), TestHandler()) .ok() .expect("Unexpected connection limit."); } - match network.peer(target.clone()) + match network + .peer(target.clone()) .dial(Multiaddr::empty(), Vec::new(), TestHandler()) .expect_err("Unexpected dialing success.") { @@ -60,10 +62,14 @@ fn max_outgoing() { let info = network.info(); assert_eq!(info.num_peers(), 0); - assert_eq!(info.connection_counters().num_pending_outgoing(), outgoing_limit); + assert_eq!( + info.connection_counters().num_pending_outgoing(), + outgoing_limit + ); // Abort all dialing attempts. - let mut peer = network.peer(target.clone()) + let mut peer = network + .peer(target.clone()) .into_dialing() .expect("Unexpected peer state"); @@ -72,7 +78,10 @@ fn max_outgoing() { attempt.abort(); } - assert_eq!(network.info().connection_counters().num_pending_outgoing(), 0); + assert_eq!( + network.info().connection_counters().num_pending_outgoing(), + 0 + ); } #[test] @@ -87,35 +96,34 @@ fn max_established_incoming() { let mut network1 = test_network(config(limit)); let mut network2 = test_network(config(limit)); - let listen_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127,0,0,1)), Tcp(0u16)]; + let listen_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1)), Tcp(0u16)]; let _ = network1.listen_on(listen_addr.clone()).unwrap(); let (addr_sender, addr_receiver) = futures::channel::oneshot::channel(); let mut addr_sender = Some(addr_sender); // Spawn the listener. - let listener = async_std::task::spawn(poll_fn(move |cx| { - loop { - match ready!(network1.poll(cx)) { - NetworkEvent::NewListenerAddress { listen_addr, .. } => { - addr_sender.take().unwrap().send(listen_addr).unwrap(); - } - NetworkEvent::IncomingConnection { connection, .. } => { - network1.accept(connection, TestHandler()).unwrap(); - } - NetworkEvent::ConnectionEstablished { .. } => {} - NetworkEvent::IncomingConnectionError { - error: PendingConnectionError::ConnectionLimit(err), .. - } => { - assert_eq!(err.limit, limit); - assert_eq!(err.limit, err.current); - let info = network1.info(); - let counters = info.connection_counters(); - assert_eq!(counters.num_established_incoming(), limit); - assert_eq!(counters.num_established(), limit); - return Poll::Ready(()) - } - e => panic!("Unexpected network event: {:?}", e) + let listener = async_std::task::spawn(poll_fn(move |cx| loop { + match ready!(network1.poll(cx)) { + NetworkEvent::NewListenerAddress { listen_addr, .. } => { + addr_sender.take().unwrap().send(listen_addr).unwrap(); + } + NetworkEvent::IncomingConnection { connection, .. } => { + network1.accept(connection, TestHandler()).unwrap(); } + NetworkEvent::ConnectionEstablished { .. } => {} + NetworkEvent::IncomingConnectionError { + error: PendingConnectionError::ConnectionLimit(err), + .. + } => { + assert_eq!(err.limit, limit); + assert_eq!(err.limit, err.current); + let info = network1.info(); + let counters = info.connection_counters(); + assert_eq!(counters.num_established_incoming(), limit); + assert_eq!(counters.num_established(), limit); + return Poll::Ready(()); + } + e => panic!("Unexpected network event: {:?}", e), } })); @@ -152,15 +160,15 @@ fn max_established_incoming() { let counters = info.connection_counters(); assert_eq!(counters.num_established_outgoing(), limit); assert_eq!(counters.num_established(), limit); - return Poll::Ready(()) + return Poll::Ready(()); } - e => panic!("Unexpected network event: {:?}", e) + e => panic!("Unexpected network event: {:?}", e), } } - }).await + }) + .await }); // Wait for the listener to complete. async_std::task::block_on(listener); } - diff --git a/core/tests/network_dial_error.rs b/core/tests/network_dial_error.rs index 2edb133ffc9..224d7950eac 100644 --- a/core/tests/network_dial_error.rs +++ b/core/tests/network_dial_error.rs @@ -23,14 +23,14 @@ mod util; use futures::prelude::*; use libp2p_core::multiaddr::multiaddr; use libp2p_core::{ - PeerId, connection::PendingConnectionError, multiaddr::Protocol, - network::{NetworkEvent, NetworkConfig}, + network::{NetworkConfig, NetworkEvent}, + PeerId, }; use rand::seq::SliceRandom; use std::{io, task::Poll}; -use util::{TestHandler, test_network}; +use util::{test_network, TestHandler}; #[test] fn deny_incoming_connec() { @@ -39,16 +39,16 @@ fn deny_incoming_connec() { let mut swarm1 = test_network(NetworkConfig::default()); let mut swarm2 = test_network(NetworkConfig::default()); - swarm1.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap(); + swarm1 + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); - let address = async_std::task::block_on(future::poll_fn(|cx| { - match swarm1.poll(cx) { - Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) => { - Poll::Ready(listen_addr) - } - Poll::Pending => Poll::Pending, - _ => panic!("Was expecting the listen address to be reported"), + let address = async_std::task::block_on(future::poll_fn(|cx| match swarm1.poll(cx) { + Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) => { + Poll::Ready(listen_addr) } + Poll::Pending => Poll::Pending, + _ => panic!("Was expecting the listen address to be reported"), })); swarm2 @@ -68,23 +68,26 @@ fn deny_incoming_connec() { attempts_remaining: 0, peer_id, multiaddr, - error: PendingConnectionError::Transport(_) + error: PendingConnectionError::Transport(_), }) => { assert_eq!(&peer_id, swarm1.local_peer_id()); - assert_eq!(multiaddr, address.clone().with(Protocol::P2p(peer_id.into()))); + assert_eq!( + multiaddr, + address.clone().with(Protocol::P2p(peer_id.into())) + ); return Poll::Ready(Ok(())); - }, + } Poll::Ready(_) => unreachable!(), Poll::Pending => (), } Poll::Pending - })).unwrap(); + })) + .unwrap(); } #[test] fn dial_self() { - // Check whether dialing ourselves correctly fails. // // Dialing the same address we're listening should result in three events: @@ -96,16 +99,16 @@ fn dial_self() { // The last two can happen in any order. let mut swarm = test_network(NetworkConfig::default()); - swarm.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap(); + swarm + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); - let local_address = async_std::task::block_on(future::poll_fn(|cx| { - match swarm.poll(cx) { - Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) => { - Poll::Ready(listen_addr) - } - Poll::Pending => Poll::Pending, - _ => panic!("Was expecting the listen address to be reported"), + let local_address = async_std::task::block_on(future::poll_fn(|cx| match swarm.poll(cx) { + Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) => { + Poll::Ready(listen_addr) } + Poll::Pending => Poll::Pending, + _ => panic!("Was expecting the listen address to be reported"), })); swarm.dial(&local_address, TestHandler()).unwrap(); @@ -124,30 +127,29 @@ fn dial_self() { assert_eq!(multiaddr, local_address); got_dial_err = true; if got_inc_err { - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } - }, - Poll::Ready(NetworkEvent::IncomingConnectionError { - local_addr, .. - }) => { + } + Poll::Ready(NetworkEvent::IncomingConnectionError { local_addr, .. }) => { assert!(!got_inc_err); assert_eq!(local_addr, local_address); got_inc_err = true; if got_dial_err { - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } - }, + } Poll::Ready(NetworkEvent::IncomingConnection { connection, .. }) => { assert_eq!(&connection.local_addr, &local_address); swarm.accept(connection, TestHandler()).unwrap(); - }, + } Poll::Ready(ev) => { panic!("Unexpected event: {:?}", ev) } Poll::Pending => break Poll::Pending, } } - })).unwrap(); + })) + .unwrap(); } #[test] @@ -168,23 +170,19 @@ fn multiple_addresses_err() { let mut swarm = test_network(NetworkConfig::default()); let mut addresses = Vec::new(); - for _ in 0 .. 3 { - addresses.push(multiaddr![ - Ip4([0, 0, 0, 0]), - Tcp(rand::random::()) - ]); + for _ in 0..3 { + addresses.push(multiaddr![Ip4([0, 0, 0, 0]), Tcp(rand::random::())]); } - for _ in 0 .. 5 { - addresses.push(multiaddr![ - Udp(rand::random::()) - ]); + for _ in 0..5 { + addresses.push(multiaddr![Udp(rand::random::())]); } addresses.shuffle(&mut rand::thread_rng()); let first = addresses[0].clone(); let rest = (&addresses[1..]).iter().cloned(); - swarm.peer(target.clone()) + swarm + .peer(target.clone()) .dial(first, rest, TestHandler()) .unwrap(); @@ -195,10 +193,12 @@ fn multiple_addresses_err() { attempts_remaining, peer_id, multiaddr, - error: PendingConnectionError::Transport(_) + error: PendingConnectionError::Transport(_), }) => { assert_eq!(peer_id, target); - let expected = addresses.remove(0).with(Protocol::P2p(target.clone().into())); + let expected = addresses + .remove(0) + .with(Protocol::P2p(target.clone().into())); assert_eq!(multiaddr, expected); if addresses.is_empty() { assert_eq!(attempts_remaining, 0); @@ -206,10 +206,11 @@ fn multiple_addresses_err() { } else { assert_eq!(attempts_remaining, addresses.len() as u32); } - }, + } Poll::Ready(_) => unreachable!(), Poll::Pending => break Poll::Pending, } } - })).unwrap(); + })) + .unwrap(); } diff --git a/core/tests/transport_upgrade.rs b/core/tests/transport_upgrade.rs index 191e4b14e81..f02fb2f3bd7 100644 --- a/core/tests/transport_upgrade.rs +++ b/core/tests/transport_upgrade.rs @@ -22,8 +22,8 @@ mod util; use futures::prelude::*; use libp2p_core::identity; -use libp2p_core::transport::{Transport, MemoryTransport}; -use libp2p_core::upgrade::{self, UpgradeInfo, InboundUpgrade, OutboundUpgrade}; +use libp2p_core::transport::{MemoryTransport, Transport}; +use libp2p_core::upgrade::{self, InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use libp2p_mplex::MplexConfig; use libp2p_noise as noise; use multiaddr::{Multiaddr, Protocol}; @@ -44,7 +44,7 @@ impl UpgradeInfo for HelloUpgrade { impl InboundUpgrade for HelloUpgrade where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = C; type Error = io::Error; @@ -81,7 +81,9 @@ where fn upgrade_pipeline() { let listener_keys = identity::Keypair::generate_ed25519(); let listener_id = listener_keys.public().to_peer_id(); - let listener_noise_keys = noise::Keypair::::new().into_authentic(&listener_keys).unwrap(); + let listener_noise_keys = noise::Keypair::::new() + .into_authentic(&listener_keys) + .unwrap(); let listener_transport = MemoryTransport::default() .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(listener_noise_keys).into_authenticated()) @@ -97,7 +99,9 @@ fn upgrade_pipeline() { let dialer_keys = identity::Keypair::generate_ed25519(); let dialer_id = dialer_keys.public().to_peer_id(); - let dialer_noise_keys = noise::Keypair::::new().into_authentic(&dialer_keys).unwrap(); + let dialer_noise_keys = noise::Keypair::::new() + .into_authentic(&dialer_keys) + .unwrap(); let dialer_transport = MemoryTransport::default() .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(dialer_noise_keys).into_authenticated()) @@ -121,7 +125,7 @@ fn upgrade_pipeline() { let (upgrade, _remote_addr) = match listener.next().await.unwrap().unwrap().into_upgrade() { Some(u) => u, - None => continue + None => continue, }; let (peer, _mplex) = upgrade.await.unwrap(); assert_eq!(peer, dialer_id); @@ -136,4 +140,3 @@ fn upgrade_pipeline() { async_std::task::spawn(server); async_std::task::block_on(client); } - diff --git a/core/tests/util.rs b/core/tests/util.rs index 0437f90867e..0c175448336 100644 --- a/core/tests/util.rs +++ b/core/tests/util.rs @@ -1,22 +1,12 @@ - #![allow(dead_code)] use futures::prelude::*; use libp2p_core::{ - Multiaddr, - PeerId, - Transport, - connection::{ - ConnectionHandler, - ConnectionHandlerEvent, - Substream, - SubstreamEndpoint, - }, + connection::{ConnectionHandler, ConnectionHandlerEvent, Substream, SubstreamEndpoint}, identity, muxing::{StreamMuxer, StreamMuxerBox}, network::{Network, NetworkConfig}, - transport, - upgrade, + transport, upgrade, Multiaddr, PeerId, Transport, }; use libp2p_mplex as mplex; use libp2p_noise as noise; @@ -30,7 +20,9 @@ type TestTransport = transport::Boxed<(PeerId, StreamMuxerBox)>; pub fn test_network(cfg: NetworkConfig) -> TestNetwork { let local_key = identity::Keypair::generate_ed25519(); let local_public_key = local_key.public(); - let noise_keys = noise::Keypair::::new().into_authentic(&local_key).unwrap(); + let noise_keys = noise::Keypair::::new() + .into_authentic(&local_key) + .unwrap(); let transport: TestTransport = tcp::TcpConfig::new() .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) @@ -49,17 +41,21 @@ impl ConnectionHandler for TestHandler { type Substream = Substream; type OutboundOpenInfo = (); - fn inject_substream(&mut self, _: Self::Substream, _: SubstreamEndpoint) - {} + fn inject_substream( + &mut self, + _: Self::Substream, + _: SubstreamEndpoint, + ) { + } - fn inject_event(&mut self, _: Self::InEvent) - {} + fn inject_event(&mut self, _: Self::InEvent) {} - fn inject_address_change(&mut self, _: &Multiaddr) - {} + fn inject_address_change(&mut self, _: &Multiaddr) {} - fn poll(&mut self, _: &mut Context<'_>) - -> Poll, Self::Error>> + fn poll( + &mut self, + _: &mut Context<'_>, + ) -> Poll, Self::Error>> { Poll::Pending } @@ -72,7 +68,7 @@ pub struct CloseMuxer { impl CloseMuxer { pub fn new(m: M) -> CloseMuxer { CloseMuxer { - state: CloseMuxerState::Close(m) + state: CloseMuxerState::Close(m), } } } @@ -85,7 +81,7 @@ pub enum CloseMuxerState { impl Future for CloseMuxer where M: StreamMuxer, - M::Error: From + M::Error: From, { type Output = Result; @@ -95,15 +91,14 @@ where CloseMuxerState::Close(muxer) => { if !muxer.close(cx)?.is_ready() { self.state = CloseMuxerState::Close(muxer); - return Poll::Pending + return Poll::Pending; } - return Poll::Ready(Ok(muxer)) + return Poll::Ready(Ok(muxer)); } - CloseMuxerState::Done => panic!() + CloseMuxerState::Done => panic!(), } } } } -impl Unpin for CloseMuxer { -} +impl Unpin for CloseMuxer {} diff --git a/examples/chat-tokio.rs b/examples/chat-tokio.rs index 45f35192534..202b5b39156 100644 --- a/examples/chat-tokio.rs +++ b/examples/chat-tokio.rs @@ -38,19 +38,19 @@ use futures::StreamExt; use libp2p::{ - Multiaddr, - NetworkBehaviour, - PeerId, - Transport, core::upgrade, - identity, floodsub::{self, Floodsub, FloodsubEvent}, + identity, mdns::{Mdns, MdnsEvent}, mplex, noise, swarm::{NetworkBehaviourEventProcess, SwarmBuilder, SwarmEvent}, // `TokioTcpConfig` is available through the `tcp-tokio` feature. tcp::TokioTcpConfig, + Multiaddr, + NetworkBehaviour, + PeerId, + Transport, }; use std::error::Error; use tokio::io::{self, AsyncBufReadExt}; @@ -72,7 +72,8 @@ async fn main() -> Result<(), Box> { // Create a tokio-based TCP transport use noise for authenticated // encryption and Mplex for multiplexing of substreams on a TCP stream. - let transport = TokioTcpConfig::new().nodelay(true) + let transport = TokioTcpConfig::new() + .nodelay(true) .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(mplex::MplexConfig::new()) @@ -95,7 +96,11 @@ async fn main() -> Result<(), Box> { // Called when `floodsub` produces an event. fn inject_event(&mut self, message: FloodsubEvent) { if let FloodsubEvent::Message(message) = message { - println!("Received: '{:?}' from {:?}", String::from_utf8_lossy(&message.data), message.source); + println!( + "Received: '{:?}' from {:?}", + String::from_utf8_lossy(&message.data), + message.source + ); } } } @@ -104,16 +109,18 @@ async fn main() -> Result<(), Box> { // Called when `mdns` produces an event. fn inject_event(&mut self, event: MdnsEvent) { match event { - MdnsEvent::Discovered(list) => + MdnsEvent::Discovered(list) => { for (peer, _) in list { self.floodsub.add_node_to_partial_view(peer); } - MdnsEvent::Expired(list) => + } + MdnsEvent::Expired(list) => { for (peer, _) in list { if !self.mdns.has_node(&peer) { self.floodsub.remove_node_from_partial_view(&peer); } } + } } } } @@ -131,7 +138,9 @@ async fn main() -> Result<(), Box> { SwarmBuilder::new(transport, behaviour, peer_id) // We want the connection background tasks to be spawned // onto the tokio runtime. - .executor(Box::new(|fut| { tokio::spawn(fut); })) + .executor(Box::new(|fut| { + tokio::spawn(fut); + })) .build() }; diff --git a/examples/chat.rs b/examples/chat.rs index fddec5e5e9b..18ef72b96b8 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -31,7 +31,7 @@ //! # If they don't automatically connect //! //! If the nodes don't automatically connect, take note of the listening addresses of the first -//! instance and start the second with one of the addresses as the first argument. In the first +//! instance and start the second with one of the addresses as the first argument. In the first //! terminal window, run: //! //! ```sh @@ -52,16 +52,16 @@ use async_std::{io, task}; use futures::{future, prelude::*}; use libp2p::{ - Multiaddr, - PeerId, - Swarm, - NetworkBehaviour, - identity, floodsub::{self, Floodsub, FloodsubEvent}, + identity, mdns::{Mdns, MdnsConfig, MdnsEvent}, - swarm::{NetworkBehaviourEventProcess, SwarmEvent} + swarm::{NetworkBehaviourEventProcess, SwarmEvent}, + Multiaddr, NetworkBehaviour, PeerId, Swarm, +}; +use std::{ + error::Error, + task::{Context, Poll}, }; -use std::{error::Error, task::{Context, Poll}}; #[async_std::main] async fn main() -> Result<(), Box> { @@ -97,7 +97,11 @@ async fn main() -> Result<(), Box> { // Called when `floodsub` produces an event. fn inject_event(&mut self, message: FloodsubEvent) { if let FloodsubEvent::Message(message) = message { - println!("Received: '{:?}' from {:?}", String::from_utf8_lossy(&message.data), message.source); + println!( + "Received: '{:?}' from {:?}", + String::from_utf8_lossy(&message.data), + message.source + ); } } } @@ -106,16 +110,18 @@ async fn main() -> Result<(), Box> { // Called when `mdns` produces an event. fn inject_event(&mut self, event: MdnsEvent) { match event { - MdnsEvent::Discovered(list) => + MdnsEvent::Discovered(list) => { for (peer, _) in list { self.floodsub.add_node_to_partial_view(peer); } - MdnsEvent::Expired(list) => + } + MdnsEvent::Expired(list) => { for (peer, _) in list { if !self.mdns.has_node(&peer) { self.floodsub.remove_node_from_partial_view(&peer); } } + } } } } @@ -150,11 +156,12 @@ async fn main() -> Result<(), Box> { task::block_on(future::poll_fn(move |cx: &mut Context<'_>| { loop { match stdin.try_poll_next_unpin(cx)? { - Poll::Ready(Some(line)) => swarm.behaviour_mut() + Poll::Ready(Some(line)) => swarm + .behaviour_mut() .floodsub .publish(floodsub_topic.clone(), line.as_bytes()), Poll::Ready(None) => panic!("Stdin closed"), - Poll::Pending => break + Poll::Pending => break, } } loop { diff --git a/examples/distributed-key-value-store.rs b/examples/distributed-key-value-store.rs index 9ab5b7206d7..2e5fa5a8531 100644 --- a/examples/distributed-key-value-store.rs +++ b/examples/distributed-key-value-store.rs @@ -44,26 +44,19 @@ use async_std::{io, task}; use futures::prelude::*; use libp2p::kad::record::store::MemoryStore; use libp2p::kad::{ - AddProviderOk, - Kademlia, - KademliaEvent, - PeerRecord, - PutRecordOk, - QueryResult, - Quorum, - Record, - record::Key, + record::Key, AddProviderOk, Kademlia, KademliaEvent, PeerRecord, PutRecordOk, QueryResult, + Quorum, Record, }; use libp2p::{ - NetworkBehaviour, - PeerId, - Swarm, - development_transport, - identity, + development_transport, identity, mdns::{Mdns, MdnsConfig, MdnsEvent}, - swarm::{NetworkBehaviourEventProcess, SwarmEvent} + swarm::{NetworkBehaviourEventProcess, SwarmEvent}, + NetworkBehaviour, PeerId, Swarm, +}; +use std::{ + error::Error, + task::{Context, Poll}, }; -use std::{error::Error, task::{Context, Poll}}; #[async_std::main] async fn main() -> Result<(), Box> { @@ -80,7 +73,7 @@ async fn main() -> Result<(), Box> { #[derive(NetworkBehaviour)] struct MyBehaviour { kademlia: Kademlia, - mdns: Mdns + mdns: Mdns, } impl NetworkBehaviourEventProcess for MyBehaviour { @@ -112,7 +105,11 @@ async fn main() -> Result<(), Box> { eprintln!("Failed to get providers: {:?}", err); } QueryResult::GetRecord(Ok(ok)) => { - for PeerRecord { record: Record { key, value, .. }, ..} in ok.records { + for PeerRecord { + record: Record { key, value, .. }, + .. + } in ok.records + { println!( "Got record {:?} {:?}", std::str::from_utf8(key.as_ref()).unwrap(), @@ -133,7 +130,8 @@ async fn main() -> Result<(), Box> { eprintln!("Failed to put record: {:?}", err); } QueryResult::StartProviding(Ok(AddProviderOk { key })) => { - println!("Successfully put provider record {:?}", + println!( + "Successfully put provider record {:?}", std::str::from_utf8(key.as_ref()).unwrap() ); } @@ -141,7 +139,7 @@ async fn main() -> Result<(), Box> { eprintln!("Failed to put provider record: {:?}", err); } _ => {} - } + }, _ => {} } } @@ -167,9 +165,11 @@ async fn main() -> Result<(), Box> { task::block_on(future::poll_fn(move |cx: &mut Context<'_>| { loop { match stdin.try_poll_next_unpin(cx)? { - Poll::Ready(Some(line)) => handle_input_line(&mut swarm.behaviour_mut().kademlia, line), + Poll::Ready(Some(line)) => { + handle_input_line(&mut swarm.behaviour_mut().kademlia, line) + } Poll::Ready(None) => panic!("Stdin closed"), - Poll::Pending => break + Poll::Pending => break, } } loop { @@ -209,7 +209,7 @@ fn handle_input_line(kademlia: &mut Kademlia, line: String) { Some(key) => Key::new(&key), None => { eprintln!("Expected key"); - return + return; } } }; @@ -240,8 +240,10 @@ fn handle_input_line(kademlia: &mut Kademlia, line: String) { publisher: None, expires: None, }; - kademlia.put_record(record, Quorum::One).expect("Failed to store record locally."); - }, + kademlia + .put_record(record, Quorum::One) + .expect("Failed to store record locally."); + } Some("PUT_PROVIDER") => { let key = { match args.next() { @@ -253,7 +255,9 @@ fn handle_input_line(kademlia: &mut Kademlia, line: String) { } }; - kademlia.start_providing(key).expect("Failed to start providing key"); + kademlia + .start_providing(key) + .expect("Failed to start providing key"); } _ => { eprintln!("expected GET, GET_PROVIDERS, PUT or PUT_PROVIDER"); diff --git a/examples/gossipsub-chat.rs b/examples/gossipsub-chat.rs index f56fe708075..bbf1190f8c3 100644 --- a/examples/gossipsub-chat.rs +++ b/examples/gossipsub-chat.rs @@ -28,7 +28,7 @@ //! chat members and everyone will receive all messages. //! //! In order to get the nodes to connect, take note of the listening addresses of the first -//! instance and start the second with one of the addresses as the first argument. In the first +//! instance and start the second with one of the addresses as the first argument. In the first //! terminal window, run: //! //! ```sh diff --git a/examples/ipfs-kad.rs b/examples/ipfs-kad.rs index c1e7e5c66d2..b3e6b211b46 100644 --- a/examples/ipfs-kad.rs +++ b/examples/ipfs-kad.rs @@ -25,28 +25,20 @@ use async_std::task; use futures::StreamExt; +use libp2p::kad::record::store::MemoryStore; +use libp2p::kad::{GetClosestPeersError, Kademlia, KademliaConfig, KademliaEvent, QueryResult}; use libp2p::{ - Multiaddr, + development_transport, identity, swarm::{Swarm, SwarmEvent}, - PeerId, - identity, - development_transport -}; -use libp2p::kad::{ - Kademlia, - KademliaConfig, - KademliaEvent, - GetClosestPeersError, - QueryResult, + Multiaddr, PeerId, }; -use libp2p::kad::record::store::MemoryStore; use std::{env, error::Error, str::FromStr, time::Duration}; const BOOTNODES: [&'static str; 4] = [ "QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN", "QmQCU2EcMqAqQPR2i9bChDtGNJchTbq5TbXJJ16u19uLTa", "QmbLHAnMoJPWSCR5Zhtx6BHJX9KiKNN6tpvbUcqanj75Nb", - "QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt" + "QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt", ]; #[async_std::main] @@ -96,9 +88,10 @@ async fn main() -> Result<(), Box> { if let SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { result: QueryResult::GetClosestPeers(result), .. - }) = event { + }) = event + { match result { - Ok(ok) => + Ok(ok) => { if !ok.peers.is_empty() { println!("Query finished with closest peers: {:#?}", ok.peers) } else { @@ -106,7 +99,8 @@ async fn main() -> Result<(), Box> { // should always be at least 1 reachable peer. println!("Query finished with no closest peers.") } - Err(GetClosestPeersError::Timeout { peers, .. }) => + } + Err(GetClosestPeersError::Timeout { peers, .. }) => { if !peers.is_empty() { println!("Query timed out with closest peers: {:#?}", peers) } else { @@ -114,6 +108,7 @@ async fn main() -> Result<(), Box> { // should always be at least 1 reachable peer. println!("Query timed out with no closest peers."); } + } }; break; diff --git a/examples/mdns-passive-discovery.rs b/examples/mdns-passive-discovery.rs index bce18dea1ee..a63ec7d5afe 100644 --- a/examples/mdns-passive-discovery.rs +++ b/examples/mdns-passive-discovery.rs @@ -20,10 +20,10 @@ use futures::StreamExt; use libp2p::{ - identity, - mdns::{Mdns, MdnsConfig, MdnsEvent}, + identity, + mdns::{Mdns, MdnsConfig, MdnsEvent}, swarm::{Swarm, SwarmEvent}, - PeerId + PeerId, }; use std::error::Error; diff --git a/examples/ping.rs b/examples/ping.rs index 151e9a5b5dd..f38b4fc4011 100644 --- a/examples/ping.rs +++ b/examples/ping.rs @@ -79,7 +79,7 @@ fn main() -> Result<(), Box> { block_on(future::poll_fn(move |cx| loop { match swarm.poll_next_unpin(cx) { Poll::Ready(Some(event)) => match event { - SwarmEvent::NewListenAddr{ address, .. } => println!("Listening on {:?}", address), + SwarmEvent::NewListenAddr { address, .. } => println!("Listening on {:?}", address), SwarmEvent::Behaviour(event) => println!("{:?}", event), _ => {} }, diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index 34344bcd556..7a8c75daa6f 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -20,11 +20,16 @@ //! Protocol negotiation strategies for the peer acting as the dialer. +use crate::protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError}; use crate::{Negotiated, NegotiationError, Version}; -use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine}; use futures::{future::Either, prelude::*}; -use std::{convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}}; +use std::{ + convert::TryFrom as _, + iter, mem, + pin::Pin, + task::{Context, Poll}, +}; /// Returns a `Future` that negotiates a protocol on the given I/O stream /// for a peer acting as the _dialer_ (or _initiator_). @@ -48,17 +53,17 @@ use std::{convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}}; pub fn dialer_select_proto( inner: R, protocols: I, - version: Version + version: Version, ) -> DialerSelectFuture where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { let iter = protocols.into_iter(); // We choose between the "serial" and "parallel" strategies based on the number of protocols. if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) { - Either::Left(dialer_select_proto_serial(inner, iter, version)) + Either::Left(dialer_select_proto_serial(inner, iter, version)) } else { Either::Right(dialer_select_proto_parallel(inner, iter, version)) } @@ -79,12 +84,12 @@ pub type DialerSelectFuture = Either, DialerSelectPa pub(crate) fn dialer_select_proto_serial( inner: R, protocols: I, - version: Version + version: Version, ) -> DialerSelectSeq where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { let protocols = protocols.into_iter().peekable(); DialerSelectSeq { @@ -92,7 +97,7 @@ where protocols, state: SeqState::SendHeader { io: MessageIO::new(inner), - } + }, } } @@ -108,20 +113,20 @@ where pub(crate) fn dialer_select_proto_parallel( inner: R, protocols: I, - version: Version + version: Version, ) -> DialerSelectPar where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { let protocols = protocols.into_iter(); DialerSelectPar { version, protocols, state: ParState::SendHeader { - io: MessageIO::new(inner) - } + io: MessageIO::new(inner), + }, } } @@ -136,11 +141,11 @@ pub struct DialerSelectSeq { } enum SeqState { - SendHeader { io: MessageIO, }, + SendHeader { io: MessageIO }, SendProtocol { io: MessageIO, protocol: N }, FlushProtocol { io: MessageIO, protocol: N }, AwaitProtocol { io: MessageIO, protocol: N }, - Done + Done, } impl Future for DialerSelectSeq @@ -149,7 +154,7 @@ where // It also makes the implementation considerably easier to write. R: AsyncRead + AsyncWrite + Unpin, I: Iterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { type Output = Result<(I::Item, Negotiated), NegotiationError>; @@ -160,11 +165,11 @@ where match mem::replace(this.state, SeqState::Done) { SeqState::SendHeader { mut io } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = SeqState::SendHeader { io }; - return Poll::Pending - }, + return Poll::Pending; + } } let h = HeaderLine::from(*this.version); @@ -181,11 +186,11 @@ where SeqState::SendProtocol { mut io, protocol } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = SeqState::SendProtocol { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } } let p = Protocol::try_from(protocol.as_ref())?; @@ -207,7 +212,7 @@ where log::debug!("Dialer: Expecting proposed protocol: {}", p); let hl = HeaderLine::from(Version::V1Lazy); let io = Negotiated::expecting(io.into_reader(), p, Some(hl)); - return Poll::Ready(Ok((protocol, io))) + return Poll::Ready(Ok((protocol, io))); } } } @@ -218,8 +223,8 @@ where Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol }, Poll::Pending => { *this.state = SeqState::FlushProtocol { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } } } @@ -228,7 +233,7 @@ where Poll::Ready(Some(msg)) => msg, Poll::Pending => { *this.state = SeqState::AwaitProtocol { io, protocol }; - return Poll::Pending + return Poll::Pending; } // Treat EOF error as [`NegotiationError::Failed`], not as // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O @@ -246,16 +251,18 @@ where return Poll::Ready(Ok((protocol, io))); } Message::NotAvailable => { - log::debug!("Dialer: Received rejection of protocol: {}", - String::from_utf8_lossy(protocol.as_ref())); + log::debug!( + "Dialer: Received rejection of protocol: {}", + String::from_utf8_lossy(protocol.as_ref()) + ); let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; *this.state = SeqState::SendProtocol { io, protocol } } - _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), } } - SeqState::Done => panic!("SeqState::poll called after completion") + SeqState::Done => panic!("SeqState::poll called after completion"), } } } @@ -277,7 +284,7 @@ enum ParState { Flush { io: MessageIO }, RecvProtocols { io: MessageIO }, SendProtocol { io: MessageIO, protocol: N }, - Done + Done, } impl Future for DialerSelectPar @@ -286,7 +293,7 @@ where // It also makes the implementation considerably easier to write. R: AsyncRead + AsyncWrite + Unpin, I: Iterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { type Output = Result<(I::Item, Negotiated), NegotiationError>; @@ -297,11 +304,11 @@ where match mem::replace(this.state, ParState::Done) { ParState::SendHeader { mut io } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = ParState::SendHeader { io }; - return Poll::Pending - }, + return Poll::Pending; + } } let msg = Message::Header(HeaderLine::from(*this.version)); @@ -314,11 +321,11 @@ where ParState::SendProtocolsRequest { mut io } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = ParState::SendProtocolsRequest { io }; - return Poll::Pending - }, + return Poll::Pending; + } } if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) { @@ -329,22 +336,20 @@ where *this.state = ParState::Flush { io } } - ParState::Flush { mut io } => { - match Pin::new(&mut io).poll_flush(cx)? { - Poll::Ready(()) => *this.state = ParState::RecvProtocols { io }, - Poll::Pending => { - *this.state = ParState::Flush { io }; - return Poll::Pending - }, + ParState::Flush { mut io } => match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => *this.state = ParState::RecvProtocols { io }, + Poll::Pending => { + *this.state = ParState::Flush { io }; + return Poll::Pending; } - } + }, ParState::RecvProtocols { mut io } => { let msg = match Pin::new(&mut io).poll_next(cx)? { Poll::Ready(Some(msg)) => msg, Poll::Pending => { *this.state = ParState::RecvProtocols { io }; - return Poll::Pending + return Poll::Pending; } // Treat EOF error as [`NegotiationError::Failed`], not as // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O @@ -357,12 +362,15 @@ where *this.state = ParState::RecvProtocols { io } } Message::Protocols(supported) => { - let protocol = this.protocols.by_ref() - .find(|p| supported.iter().any(|s| - s.as_ref() == p.as_ref())) + let protocol = this + .protocols + .by_ref() + .find(|p| supported.iter().any(|s| s.as_ref() == p.as_ref())) .ok_or(NegotiationError::Failed)?; - log::debug!("Dialer: Found supported protocol: {}", - String::from_utf8_lossy(protocol.as_ref())); + log::debug!( + "Dialer: Found supported protocol: {}", + String::from_utf8_lossy(protocol.as_ref()) + ); *this.state = ParState::SendProtocol { io, protocol }; } _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), @@ -371,11 +379,11 @@ where ParState::SendProtocol { mut io, protocol } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = ParState::SendProtocol { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } } let p = Protocol::try_from(protocol.as_ref())?; @@ -386,10 +394,10 @@ where log::debug!("Dialer: Expecting proposed protocol: {}", p); let io = Negotiated::expecting(io.into_reader(), p, None); - return Poll::Ready(Ok((protocol, io))) + return Poll::Ready(Ok((protocol, io))); } - ParState::Done => panic!("ParState::poll called after completion") + ParState::Done => panic!("ParState::poll called after completion"), } } } diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 593c915ac2b..abb622eed30 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -18,9 +18,15 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use bytes::{Bytes, BytesMut, Buf as _, BufMut as _}; -use futures::{prelude::*, io::IoSlice}; -use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16}; +use bytes::{Buf as _, BufMut as _, Bytes, BytesMut}; +use futures::{io::IoSlice, prelude::*}; +use std::{ + convert::TryFrom as _, + io, + pin::Pin, + task::{Context, Poll}, + u16, +}; const MAX_LEN_BYTES: u16 = 2; const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; @@ -50,7 +56,10 @@ pub struct LengthDelimited { #[derive(Debug, Copy, Clone, PartialEq, Eq)] enum ReadState { /// We are currently reading the length of the next frame of data. - ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize }, + ReadLength { + buf: [u8; MAX_LEN_BYTES as usize], + pos: usize, + }, /// We are currently reading the frame of data itself. ReadData { len: u16, pos: usize }, } @@ -59,7 +68,7 @@ impl Default for ReadState { fn default() -> Self { ReadState::ReadLength { buf: [0; MAX_LEN_BYTES as usize], - pos: 0 + pos: 0, } } } @@ -106,10 +115,12 @@ impl LengthDelimited { /// /// After this method returns `Poll::Ready`, the write buffer of frames /// submitted to the `Sink` is guaranteed to be empty. - pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll> + pub fn poll_write_buffer( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> where - R: AsyncWrite + R: AsyncWrite, { let mut this = self.project(); @@ -119,7 +130,8 @@ impl LengthDelimited { Poll::Ready(Ok(0)) => { return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, - "Failed to write buffered frame."))) + "Failed to write buffered frame.", + ))) } Poll::Ready(Ok(n)) => this.write_buffer.advance(n), Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), @@ -132,7 +144,7 @@ impl LengthDelimited { impl Stream for LengthDelimited where - R: AsyncRead + R: AsyncRead, { type Item = Result; @@ -142,7 +154,7 @@ where loop { match this.read_state { ReadState::ReadLength { buf, pos } => { - match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) { + match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) { Poll::Ready(Ok(0)) => { if *pos == 0 { return Poll::Ready(None); @@ -160,11 +172,10 @@ where if (buf[*pos - 1] & 0x80) == 0 { // MSB is not set, indicating the end of the length prefix. - let (len, _) = unsigned_varint::decode::u16(buf) - .map_err(|e| { - log::debug!("invalid length prefix: {}", e); - io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") - })?; + let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| { + log::debug!("invalid length prefix: {}", e); + io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") + })?; if len >= 1 { *this.read_state = ReadState::ReadData { len, pos: 0 }; @@ -179,12 +190,19 @@ where // See the module documentation about the max frame len. return Poll::Ready(Some(Err(io::Error::new( io::ErrorKind::InvalidData, - "Maximum frame length exceeded")))); + "Maximum frame length exceeded", + )))); } } ReadState::ReadData { len, pos } => { - match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) { - Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))), + match this + .inner + .as_mut() + .poll_read(cx, &mut this.read_buffer[*pos..]) + { + Poll::Ready(Ok(0)) => { + return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))) + } Poll::Ready(Ok(n)) => *pos += n, Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), @@ -214,7 +232,7 @@ where // implied to be roughly 2 * MAX_FRAME_SIZE. if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize { match self.as_mut().poll_write_buffer(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -233,7 +251,8 @@ where _ => { return Err(io::Error::new( io::ErrorKind::InvalidData, - "Maximum frame size exceeded.")) + "Maximum frame size exceeded.", + )) } }; @@ -249,7 +268,7 @@ where fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Write all buffered frame data to the underlying I/O stream. match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -264,7 +283,7 @@ where fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Write all buffered frame data to the underlying I/O stream. match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -283,7 +302,7 @@ where #[derive(Debug)] pub struct LengthDelimitedReader { #[pin] - inner: LengthDelimited + inner: LengthDelimited, } impl LengthDelimitedReader { @@ -306,7 +325,7 @@ impl LengthDelimitedReader { impl Stream for LengthDelimitedReader where - R: AsyncRead + R: AsyncRead, { type Item = Result; @@ -317,17 +336,19 @@ where impl AsyncWrite for LengthDelimitedReader where - R: AsyncWrite + R: AsyncWrite, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) - -> Poll> - { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { // `this` here designates the `LengthDelimited`. let mut this = self.project().inner; // We need to flush any data previously written with the `LengthDelimited`. match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -344,15 +365,17 @@ where self.project().inner.poll_close(cx) } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) - -> Poll> - { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { // `this` here designates the `LengthDelimited`. let mut this = self.project().inner; // We need to flush any data previously written with the `LengthDelimited`. match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -366,7 +389,7 @@ where mod tests { use crate::length_delimited::LengthDelimited; use async_std::net::{TcpListener, TcpStream}; - use futures::{prelude::*, io::Cursor}; + use futures::{io::Cursor, prelude::*}; use quickcheck::*; use std::io::ErrorKind; @@ -394,9 +417,7 @@ mod tests { let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8]; data.extend(frame.clone().into_iter()); let mut framed = LengthDelimited::new(Cursor::new(data)); - let recved = futures::executor::block_on(async move { - framed.next().await - }).unwrap(); + let recved = futures::executor::block_on(async move { framed.next().await }).unwrap(); assert_eq!(recved.unwrap(), frame); } @@ -405,9 +426,7 @@ mod tests { let mut data = vec![0x81, 0x81, 0x1]; data.extend((0..16513).map(|_| 0)); let mut framed = LengthDelimited::new(Cursor::new(data)); - let recved = futures::executor::block_on(async move { - framed.next().await.unwrap() - }); + let recved = futures::executor::block_on(async move { framed.next().await.unwrap() }); if let Err(io_err) = recved { assert_eq!(io_err.kind(), ErrorKind::InvalidData) @@ -479,7 +498,8 @@ mod tests { let expected_frames = frames.clone(); let server = async_std::task::spawn(async move { let socket = listener.accept().await.unwrap().0; - let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket)); + let mut connec = + rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket)); let mut buf = vec![0u8; 0]; for expected in expected_frames { diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index 087b2a2cb21..00291f4ece8 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -94,10 +94,10 @@ mod negotiated; mod protocol; mod tests; -pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError}; -pub use self::protocol::ProtocolError; pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture}; pub use self::listener_select::{listener_select_proto, ListenerSelectFuture}; +pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError}; +pub use self::protocol::ProtocolError; /// Supported multistream-select versions. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -145,4 +145,4 @@ impl Default for Version { fn default() -> Self { Version::V1 } -} \ No newline at end of file +} diff --git a/misc/multistream-select/src/listener_select.rs b/misc/multistream-select/src/listener_select.rs index 7cf07c5fb02..aa433e40c4d 100644 --- a/misc/multistream-select/src/listener_select.rs +++ b/misc/multistream-select/src/listener_select.rs @@ -21,12 +21,18 @@ //! Protocol negotiation strategies for the peer acting as the listener //! in a multistream-select protocol negotiation. +use crate::protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError}; use crate::{Negotiated, NegotiationError}; -use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine}; use futures::prelude::*; use smallvec::SmallVec; -use std::{convert::TryFrom as _, iter::FromIterator, mem, pin::Pin, task::{Context, Poll}}; +use std::{ + convert::TryFrom as _, + iter::FromIterator, + mem, + pin::Pin, + task::{Context, Poll}, +}; /// Returns a `Future` that negotiates a protocol on the given I/O stream /// for a peer acting as the _listener_ (or _responder_). @@ -35,28 +41,29 @@ use std::{convert::TryFrom as _, iter::FromIterator, mem, pin::Pin, task::{Conte /// computation that performs the protocol negotiation with the remote. The /// returned `Future` resolves with the name of the negotiated protocol and /// a [`Negotiated`] I/O stream. -pub fn listener_select_proto( - inner: R, - protocols: I, -) -> ListenerSelectFuture +pub fn listener_select_proto(inner: R, protocols: I) -> ListenerSelectFuture where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { - let protocols = protocols.into_iter().filter_map(|n| - match Protocol::try_from(n.as_ref()) { + let protocols = protocols + .into_iter() + .filter_map(|n| match Protocol::try_from(n.as_ref()) { Ok(p) => Some((n, p)), Err(e) => { - log::warn!("Listener: Ignoring invalid protocol: {} due to {}", - String::from_utf8_lossy(n.as_ref()), e); + log::warn!( + "Listener: Ignoring invalid protocol: {} due to {}", + String::from_utf8_lossy(n.as_ref()), + e + ); None } }); ListenerSelectFuture { protocols: SmallVec::from_iter(protocols), state: State::RecvHeader { - io: MessageIO::new(inner) + io: MessageIO::new(inner), }, last_sent_na: false, } @@ -80,19 +87,25 @@ pub struct ListenerSelectFuture { } enum State { - RecvHeader { io: MessageIO }, - SendHeader { io: MessageIO }, - RecvMessage { io: MessageIO }, + RecvHeader { + io: MessageIO, + }, + SendHeader { + io: MessageIO, + }, + RecvMessage { + io: MessageIO, + }, SendMessage { io: MessageIO, message: Message, - protocol: Option + protocol: Option, }, Flush { io: MessageIO, - protocol: Option + protocol: Option, }, - Done + Done, } impl Future for ListenerSelectFuture @@ -100,7 +113,7 @@ where // The Unpin bound here is required because we produce a `Negotiated` as the output. // It also makes the implementation considerably easier to write. R: AsyncRead + AsyncWrite + Unpin, - N: AsRef<[u8]> + Clone + N: AsRef<[u8]> + Clone, { type Output = Result<(N, Negotiated), NegotiationError>; @@ -111,14 +124,12 @@ where match mem::replace(this.state, State::Done) { State::RecvHeader { mut io } => { match io.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(Message::Header(h)))) => { - match h { - HeaderLine::V1 => *this.state = State::SendHeader { io } - } - } + Poll::Ready(Some(Ok(Message::Header(h)))) => match h { + HeaderLine::V1 => *this.state = State::SendHeader { io }, + }, Poll::Ready(Some(Ok(_))) => { return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) - }, + } Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))), // Treat EOF error as [`NegotiationError::Failed`], not as // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O @@ -126,7 +137,7 @@ where Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), Poll::Pending => { *this.state = State::RecvHeader { io }; - return Poll::Pending + return Poll::Pending; } } } @@ -135,9 +146,9 @@ where match Pin::new(&mut io).poll_ready(cx) { Poll::Pending => { *this.state = State::SendHeader { io }; - return Poll::Pending - }, - Poll::Ready(Ok(())) => {}, + return Poll::Pending; + } + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } @@ -175,28 +186,37 @@ where // the dialer also raises `NegotiationError::Failed` when finally // reading the `N/A` response. if let ProtocolError::InvalidMessage = &err { - log::trace!("Listener: Negotiation failed with invalid \ - message after protocol rejection."); - return Poll::Ready(Err(NegotiationError::Failed)) + log::trace!( + "Listener: Negotiation failed with invalid \ + message after protocol rejection." + ); + return Poll::Ready(Err(NegotiationError::Failed)); } if let ProtocolError::IoError(e) = &err { if e.kind() == std::io::ErrorKind::UnexpectedEof { - log::trace!("Listener: Negotiation failed with EOF \ - after protocol rejection."); - return Poll::Ready(Err(NegotiationError::Failed)) + log::trace!( + "Listener: Negotiation failed with EOF \ + after protocol rejection." + ); + return Poll::Ready(Err(NegotiationError::Failed)); } } } - return Poll::Ready(Err(From::from(err))) + return Poll::Ready(Err(From::from(err))); } }; match msg { Message::ListProtocols => { - let supported = this.protocols.iter().map(|(_,p)| p).cloned().collect(); + let supported = + this.protocols.iter().map(|(_, p)| p).cloned().collect(); let message = Message::Protocols(supported); - *this.state = State::SendMessage { io, message, protocol: None } + *this.state = State::SendMessage { + io, + message, + protocol: None, + } } Message::Protocol(p) => { let protocol = this.protocols.iter().find_map(|(name, proto)| { @@ -211,28 +231,42 @@ where log::debug!("Listener: confirming protocol: {}", p); Message::Protocol(p.clone()) } else { - log::debug!("Listener: rejecting protocol: {}", - String::from_utf8_lossy(p.as_ref())); + log::debug!( + "Listener: rejecting protocol: {}", + String::from_utf8_lossy(p.as_ref()) + ); Message::NotAvailable }; - *this.state = State::SendMessage { io, message, protocol }; + *this.state = State::SendMessage { + io, + message, + protocol, + }; } - _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), } } - State::SendMessage { mut io, message, protocol } => { + State::SendMessage { + mut io, + message, + protocol, + } => { match Pin::new(&mut io).poll_ready(cx) { Poll::Pending => { - *this.state = State::SendMessage { io, message, protocol }; - return Poll::Pending - }, - Poll::Ready(Ok(())) => {}, + *this.state = State::SendMessage { + io, + message, + protocol, + }; + return Poll::Pending; + } + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } - if let Message::NotAvailable = &message { + if let Message::NotAvailable = &message { *this.last_sent_na = true; } else { *this.last_sent_na = false; @@ -249,26 +283,28 @@ where match Pin::new(&mut io).poll_flush(cx) { Poll::Pending => { *this.state = State::Flush { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } Poll::Ready(Ok(())) => { // If a protocol has been selected, finish negotiation. // Otherwise expect to receive another message. match protocol { Some(protocol) => { - log::debug!("Listener: sent confirmed protocol: {}", - String::from_utf8_lossy(protocol.as_ref())); + log::debug!( + "Listener: sent confirmed protocol: {}", + String::from_utf8_lossy(protocol.as_ref()) + ); let io = Negotiated::completed(io.into_inner()); - return Poll::Ready(Ok((protocol, io))) + return Poll::Ready(Ok((protocol, io))); } - None => *this.state = State::RecvMessage { io } + None => *this.state = State::RecvMessage { io }, } } Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } } - State::Done => panic!("State::poll called after completion") + State::Done => panic!("State::poll called after completion"), } } } diff --git a/misc/multistream-select/src/negotiated.rs b/misc/multistream-select/src/negotiated.rs index e80d579f2b4..2f78daf0376 100644 --- a/misc/multistream-select/src/negotiated.rs +++ b/misc/multistream-select/src/negotiated.rs @@ -18,11 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::protocol::{Protocol, MessageReader, Message, ProtocolError, HeaderLine}; +use crate::protocol::{HeaderLine, Message, MessageReader, Protocol, ProtocolError}; -use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready}; +use futures::{ + io::{IoSlice, IoSliceMut}, + prelude::*, + ready, +}; use pin_project::pin_project; -use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}}; +use std::{ + error::Error, + fmt, io, mem, + pin::Pin, + task::{Context, Poll}, +}; /// An I/O stream that has settled on an (application-layer) protocol to use. /// @@ -39,7 +48,7 @@ use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}}; #[derive(Debug)] pub struct Negotiated { #[pin] - state: State + state: State, } /// A `Future` that waits on the completion of protocol negotiation. @@ -57,12 +66,15 @@ where type Output = Result, NegotiationError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut io = self.inner.take().expect("NegotiatedFuture called after completion."); + let mut io = self + .inner + .take() + .expect("NegotiatedFuture called after completion."); match Negotiated::poll(Pin::new(&mut io), cx) { Poll::Pending => { self.inner = Some(io); Poll::Pending - }, + } Poll::Ready(Ok(())) => Poll::Ready(Ok(io)), Poll::Ready(Err(err)) => { self.inner = Some(io); @@ -75,7 +87,9 @@ where impl Negotiated { /// Creates a `Negotiated` in state [`State::Completed`]. pub(crate) fn completed(io: TInner) -> Self { - Negotiated { state: State::Completed { io } } + Negotiated { + state: State::Completed { io }, + } } /// Creates a `Negotiated` in state [`State::Expecting`] that is still @@ -83,25 +97,31 @@ impl Negotiated { pub(crate) fn expecting( io: MessageReader, protocol: Protocol, - header: Option + header: Option, ) -> Self { - Negotiated { state: State::Expecting { io, protocol, header } } + Negotiated { + state: State::Expecting { + io, + protocol, + header, + }, + } } /// Polls the `Negotiated` for completion. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> where - TInner: AsyncRead + AsyncWrite + Unpin + TInner: AsyncRead + AsyncWrite + Unpin, { // Flush any pending negotiation data. match self.as_mut().poll_flush(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => { // If the remote closed the stream, it is important to still // continue reading the data that was sent, if any. if e.kind() != io::ErrorKind::WriteZero { - return Poll::Ready(Err(e.into())) + return Poll::Ready(Err(e.into())); } } } @@ -109,36 +129,52 @@ impl Negotiated { let mut this = self.project(); if let StateProj::Completed { .. } = this.state.as_mut().project() { - return Poll::Ready(Ok(())); + return Poll::Ready(Ok(())); } // Read outstanding protocol negotiation messages. loop { match mem::replace(&mut *this.state, State::Invalid) { - State::Expecting { mut io, header, protocol } => { + State::Expecting { + mut io, + header, + protocol, + } => { let msg = match Pin::new(&mut io).poll_next(cx)? { Poll::Ready(Some(msg)) => msg, Poll::Pending => { - *this.state = State::Expecting { io, header, protocol }; - return Poll::Pending - }, + *this.state = State::Expecting { + io, + header, + protocol, + }; + return Poll::Pending; + } Poll::Ready(None) => { return Poll::Ready(Err(ProtocolError::IoError( - io::ErrorKind::UnexpectedEof.into()).into())); + io::ErrorKind::UnexpectedEof.into(), + ) + .into())); } }; if let Message::Header(h) = &msg { if Some(h) == header.as_ref() { - *this.state = State::Expecting { io, protocol, header: None }; - continue + *this.state = State::Expecting { + io, + protocol, + header: None, + }; + continue; } } if let Message::Protocol(p) = &msg { if p.as_ref() == protocol.as_ref() { log::debug!("Negotiated: Received confirmation for protocol: {}", p); - *this.state = State::Completed { io: io.into_inner() }; + *this.state = State::Completed { + io: io.into_inner(), + }; return Poll::Ready(Ok(())); } } @@ -146,7 +182,7 @@ impl Negotiated { return Poll::Ready(Err(NegotiationError::Failed)); } - _ => panic!("Negotiated: Invalid state") + _ => panic!("Negotiated: Invalid state"), } } } @@ -178,7 +214,10 @@ enum State { /// In this state, a protocol has been agreed upon and I/O /// on the underlying stream can commence. - Completed { #[pin] io: R }, + Completed { + #[pin] + io: R, + }, /// Temporary state while moving the `io` resource from /// `Expecting` to `Completed`. @@ -187,11 +226,13 @@ enum State { impl AsyncRead for Negotiated where - TInner: AsyncRead + AsyncWrite + Unpin + TInner: AsyncRead + AsyncWrite + Unpin, { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) - -> Poll> - { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { loop { if let StateProj::Completed { io } = self.as_mut().project().state.project() { // If protocol negotiation is complete, commence with reading. @@ -201,7 +242,7 @@ where // Poll the `Negotiated`, driving protocol negotiation to completion, // including flushing of any remaining data. match self.as_mut().poll(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } @@ -217,19 +258,21 @@ where } }*/ - fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>]) - -> Poll> - { + fn poll_read_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { loop { if let StateProj::Completed { io } = self.as_mut().project().state.project() { // If protocol negotiation is complete, commence with reading. - return io.poll_read_vectored(cx, bufs) + return io.poll_read_vectored(cx, bufs); } // Poll the `Negotiated`, driving protocol negotiation to completion, // including flushing of any remaining data. match self.as_mut().poll(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } @@ -239,9 +282,13 @@ where impl AsyncWrite for Negotiated where - TInner: AsyncWrite + AsyncRead + Unpin + TInner: AsyncWrite + AsyncRead + Unpin, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { match self.project().state.project() { StateProj::Completed { io } => io.poll_write(cx, buf), StateProj::Expecting { io, .. } => io.poll_write(cx, buf), @@ -261,7 +308,10 @@ where // Ensure all data has been flushed and expected negotiation messages // have been received. ready!(self.as_mut().poll(cx).map_err(Into::::into)?); - ready!(self.as_mut().poll_flush(cx).map_err(Into::::into)?); + ready!(self + .as_mut() + .poll_flush(cx) + .map_err(Into::::into)?); // Continue with the shutdown of the underlying I/O stream. match self.project().state.project() { @@ -271,9 +321,11 @@ where } } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) - -> Poll> - { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { match self.project().state.project() { StateProj::Completed { io } => io.poll_write_vectored(cx, bufs), StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs), @@ -307,7 +359,7 @@ impl From for NegotiationError { impl From for io::Error { fn from(err: NegotiationError) -> io::Error { if let NegotiationError::ProtocolError(e) = err { - return e.into() + return e.into(); } io::Error::new(io::ErrorKind::Other, err) } @@ -325,10 +377,10 @@ impl Error for NegotiationError { impl fmt::Display for NegotiationError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - NegotiationError::ProtocolError(p) => - fmt.write_fmt(format_args!("Protocol error: {}", p)), - NegotiationError::Failed => - fmt.write_str("Protocol negotiation failed.") + NegotiationError::ProtocolError(p) => { + fmt.write_fmt(format_args!("Protocol error: {}", p)) + } + NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."), } } } diff --git a/misc/multistream-select/src/protocol.rs b/misc/multistream-select/src/protocol.rs index af2e2825a3b..1cfdcc4b588 100644 --- a/misc/multistream-select/src/protocol.rs +++ b/misc/multistream-select/src/protocol.rs @@ -25,12 +25,18 @@ //! `Stream` and `Sink` implementations of `MessageIO` and //! `MessageReader`. -use crate::Version; use crate::length_delimited::{LengthDelimited, LengthDelimitedReader}; +use crate::Version; -use bytes::{Bytes, BytesMut, BufMut}; -use futures::{prelude::*, io::IoSlice, ready}; -use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{io::IoSlice, prelude::*, ready}; +use std::{ + convert::TryFrom, + error::Error, + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; use unsigned_varint as uvi; /// The maximum number of supported protocols that can be processed. @@ -75,7 +81,7 @@ impl TryFrom for Protocol { fn try_from(value: Bytes) -> Result { if !value.as_ref().starts_with(b"/") { - return Err(ProtocolError::InvalidProtocol) + return Err(ProtocolError::InvalidProtocol); } Ok(Protocol(value)) } @@ -160,7 +166,7 @@ impl Message { /// Decodes a `Message` from its byte representation. pub fn decode(mut msg: Bytes) -> Result { if msg == MSG_MULTISTREAM_1_0 { - return Ok(Message::Header(HeaderLine::V1)) + return Ok(Message::Header(HeaderLine::V1)); } if msg == MSG_PROTOCOL_NA { @@ -168,13 +174,14 @@ impl Message { } if msg == MSG_LS { - return Ok(Message::ListProtocols) + return Ok(Message::ListProtocols); } // If it starts with a `/`, ends with a line feed without any // other line feeds in-between, it must be a protocol name. - if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') && - !msg[.. msg.len() - 1].contains(&b'\n') + if msg.get(0) == Some(&b'/') + && msg.last() == Some(&b'\n') + && !msg[..msg.len() - 1].contains(&b'\n') { let p = Protocol::try_from(msg.split_to(msg.len() - 1))?; return Ok(Message::Protocol(p)); @@ -187,24 +194,24 @@ impl Message { loop { // A well-formed message must be terminated with a newline. if remaining == [b'\n'] { - break + break; } else if protocols.len() == MAX_PROTOCOLS { - return Err(ProtocolError::TooManyProtocols) + return Err(ProtocolError::TooManyProtocols); } // Decode the length of the next protocol name and check that // it ends with a line feed. let (len, tail) = uvi::decode::usize(remaining)?; if len == 0 || len > tail.len() || tail[len - 1] != b'\n' { - return Err(ProtocolError::InvalidMessage) + return Err(ProtocolError::InvalidMessage); } // Parse the protocol name. - let p = Protocol::try_from(Bytes::copy_from_slice(&tail[.. len - 1]))?; + let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?; protocols.push(p); // Skip ahead to the next protocol. - remaining = &tail[len ..]; + remaining = &tail[len..]; } Ok(Message::Protocols(protocols)) @@ -222,9 +229,11 @@ impl MessageIO { /// Constructs a new `MessageIO` resource wrapping the given I/O stream. pub fn new(inner: R) -> MessageIO where - R: AsyncRead + AsyncWrite + R: AsyncRead + AsyncWrite, { - Self { inner: LengthDelimited::new(inner) } + Self { + inner: LengthDelimited::new(inner), + } } /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the @@ -235,7 +244,9 @@ impl MessageIO { /// received but no more messages are written, allowing the writing of /// follow-up protocol data to commence. pub fn into_reader(self) -> MessageReader { - MessageReader { inner: self.inner.into_reader() } + MessageReader { + inner: self.inner.into_reader(), + } } /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream. @@ -265,7 +276,10 @@ where fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { let mut buf = BytesMut::new(); item.encode(&mut buf)?; - self.project().inner.start_send(buf.freeze()).map_err(From::from) + self.project() + .inner + .start_send(buf.freeze()) + .map_err(From::from) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -279,7 +293,7 @@ where impl Stream for MessageIO where - R: AsyncRead + R: AsyncRead, { type Item = Result; @@ -299,7 +313,7 @@ where #[derive(Debug)] pub struct MessageReader { #[pin] - inner: LengthDelimitedReader + inner: LengthDelimitedReader, } impl MessageReader { @@ -321,7 +335,7 @@ impl MessageReader { impl Stream for MessageReader where - R: AsyncRead + R: AsyncRead, { type Item = Result; @@ -332,9 +346,13 @@ where impl AsyncWrite for MessageReader where - TInner: AsyncWrite + TInner: AsyncWrite, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { self.project().inner.poll_write(cx, buf) } @@ -346,12 +364,19 @@ where self.project().inner.poll_close(cx) } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll> { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { self.project().inner.poll_write_vectored(cx, bufs) } } -fn poll_stream(stream: Pin<&mut S>, cx: &mut Context<'_>) -> Poll>> +fn poll_stream( + stream: Pin<&mut S>, + cx: &mut Context<'_>, +) -> Poll>> where S: Stream>, { @@ -361,7 +386,7 @@ where Err(err) => return Poll::Ready(Some(Err(err))), } } else { - return Poll::Ready(None) + return Poll::Ready(None); }; log::trace!("Received message: {:?}", msg); @@ -394,7 +419,7 @@ impl From for ProtocolError { impl From for io::Error { fn from(err: ProtocolError) -> Self { if let ProtocolError::IoError(e) = err { - return e + return e; } io::ErrorKind::InvalidData.into() } @@ -418,14 +443,10 @@ impl Error for ProtocolError { impl fmt::Display for ProtocolError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - ProtocolError::IoError(e) => - write!(fmt, "I/O error: {}", e), - ProtocolError::InvalidMessage => - write!(fmt, "Received an invalid message."), - ProtocolError::InvalidProtocol => - write!(fmt, "A protocol (name) is invalid."), - ProtocolError::TooManyProtocols => - write!(fmt, "Too many protocols received.") + ProtocolError::IoError(e) => write!(fmt, "I/O error: {}", e), + ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."), + ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."), + ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."), } } } @@ -434,8 +455,8 @@ impl fmt::Display for ProtocolError { mod tests { use super::*; use quickcheck::*; - use rand::Rng; use rand::distributions::Alphanumeric; + use rand::Rng; use std::iter; impl Arbitrary for Protocol { @@ -457,7 +478,7 @@ mod tests { 2 => Message::ListProtocols, 3 => Message::Protocol(Protocol::arbitrary(g)), 4 => Message::Protocols(Vec::arbitrary(g)), - _ => panic!() + _ => panic!(), } } } @@ -466,10 +487,11 @@ mod tests { fn encode_decode_message() { fn prop(msg: Message) { let mut buf = BytesMut::new(); - msg.encode(&mut buf).expect(&format!("Encoding message failed: {:?}", msg)); + msg.encode(&mut buf) + .expect(&format!("Encoding message failed: {:?}", msg)); match Message::decode(buf.freeze()) { Ok(m) => assert_eq!(m, msg), - Err(e) => panic!("Decoding failed: {:?}", e) + Err(e) => panic!("Decoding failed: {:?}", e), } } quickcheck(prop as fn(_)) diff --git a/misc/multistream-select/src/tests.rs b/misc/multistream-select/src/tests.rs index f03d1b1ff75..ca627d24fcf 100644 --- a/misc/multistream-select/src/tests.rs +++ b/misc/multistream-select/src/tests.rs @@ -22,9 +22,9 @@ #![cfg(test)] -use crate::{Version, NegotiationError}; use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial}; use crate::{dialer_select_proto, listener_select_proto}; +use crate::{NegotiationError, Version}; use async_std::net::{TcpListener, TcpStream}; use futures::prelude::*; @@ -54,7 +54,8 @@ fn select_proto_basic() { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; let (proto, mut io) = dialer_select_proto(connec, protos.into_iter(), version) - .await.unwrap(); + .await + .unwrap(); assert_eq!(proto, b"/proto2"); io.write_all(b"ping").await.unwrap(); @@ -79,12 +80,14 @@ fn select_proto_basic() { fn negotiation_failed() { let _ = env_logger::try_init(); - async fn run(Test { - version, - listen_protos, - dial_protos, - dial_payload - }: Test) { + async fn run( + Test { + version, + listen_protos, + dial_protos, + dial_payload, + }: Test, + ) { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener_addr = listener.local_addr().unwrap(); @@ -93,10 +96,12 @@ fn negotiation_failed() { let io = match listener_select_proto(connec, listen_protos).await { Ok((_, io)) => io, Err(NegotiationError::Failed) => return, - Err(NegotiationError::ProtocolError(e)) => panic!("Unexpected protocol error {}", e), + Err(NegotiationError::ProtocolError(e)) => { + panic!("Unexpected protocol error {}", e) + } }; match io.complete().await { - Err(NegotiationError::Failed) => {}, + Err(NegotiationError::Failed) => {} _ => panic!(), } }); @@ -106,14 +111,14 @@ fn negotiation_failed() { let mut io = match dialer_select_proto(connec, dial_protos.into_iter(), version).await { Err(NegotiationError::Failed) => return, Ok((_, io)) => io, - Err(_) => panic!() + Err(_) => panic!(), }; // The dialer may write a payload that is even sent before it // got confirmation of the last proposed protocol, when `V1Lazy` // is used. io.write_all(&dial_payload).await.unwrap(); match io.complete().await { - Err(NegotiationError::Failed) => {}, + Err(NegotiationError::Failed) => {} _ => panic!(), } }); @@ -135,10 +140,10 @@ fn negotiation_failed() { // // The choices here cover the main distinction between a single // and multiple protocols. - let protos = vec!{ + let protos = vec![ (vec!["/proto1"], vec!["/proto2"]), (vec!["/proto1", "/proto2"], vec!["/proto3", "/proto4"]), - }; + ]; // The payloads that the dialer sends after "successful" negotiation, // which may be sent even before the dialer got protocol confirmation @@ -147,7 +152,7 @@ fn negotiation_failed() { // The choices here cover the specific situations that can arise with // `V1Lazy` and which must nevertheless behave identically to `V1` w.r.t. // the outcome of the negotiation. - let payloads = vec!{ + let payloads = vec![ // No payload, in which case all versions should behave identically // in any case, i.e. the baseline test. vec![], @@ -155,13 +160,13 @@ fn negotiation_failed() { // `1` as a message length and encounters an invalid message (the // second `1`). The listener is nevertheless expected to fail // negotiation normally, just like with `V1`. - vec![1,1], + vec![1, 1], // With this payload and `V1Lazy`, the listener interprets the first // `42` as a message length and encounters unexpected EOF trying to // read a message of that length. The listener is nevertheless expected // to fail negotiation normally, just like with `V1` - vec![42,1], - }; + vec![42, 1], + ]; for (listen_protos, dial_protos) in protos { for dial_payload in payloads.clone() { @@ -195,7 +200,8 @@ fn select_proto_parallel() { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; let (proto, io) = dialer_select_proto_parallel(connec, protos.into_iter(), version) - .await.unwrap(); + .await + .unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); }); @@ -226,7 +232,8 @@ fn select_proto_serial() { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version) - .await.unwrap(); + .await + .unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); }); diff --git a/misc/multistream-select/tests/transport.rs b/misc/multistream-select/tests/transport.rs index 85f51834ac9..f42797f13ca 100644 --- a/misc/multistream-select/tests/transport.rs +++ b/misc/multistream-select/tests/transport.rs @@ -18,24 +18,23 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use futures::{channel::oneshot, prelude::*, ready}; use libp2p_core::{ connection::{ConnectionHandler, ConnectionHandlerEvent, Substream, SubstreamEndpoint}, identity, - muxing::StreamMuxerBox, - upgrade, multiaddr::Protocol, - Multiaddr, - Network, - network::{NetworkEvent, NetworkConfig}, - PeerId, - Transport, - transport::{self, MemoryTransport} + muxing::StreamMuxerBox, + network::{NetworkConfig, NetworkEvent}, + transport::{self, MemoryTransport}, + upgrade, Multiaddr, Network, PeerId, Transport, }; use libp2p_mplex::MplexConfig; use libp2p_plaintext::PlainText2Config; -use futures::{channel::oneshot, ready, prelude::*}; use rand::random; -use std::{io, task::{Context, Poll}}; +use std::{ + io, + task::{Context, Poll}, +}; type TestTransport = transport::Boxed<(PeerId, StreamMuxerBox)>; type TestNetwork = Network; @@ -43,11 +42,16 @@ type TestNetwork = Network; fn mk_transport(up: upgrade::Version) -> (PeerId, TestTransport) { let keys = identity::Keypair::generate_ed25519(); let id = keys.public().to_peer_id(); - (id, MemoryTransport::default() - .upgrade(up) - .authenticate(PlainText2Config { local_public_key: keys.public() }) - .multiplex(MplexConfig::default()) - .boxed()) + ( + id, + MemoryTransport::default() + .upgrade(up) + .authenticate(PlainText2Config { + local_public_key: keys.public(), + }) + .multiplex(MplexConfig::default()) + .boxed(), + ) } /// Tests the transport upgrade process with all supported @@ -63,7 +67,8 @@ fn transport_upgrade() { let listen_addr = Multiaddr::from(Protocol::Memory(random::())); let mut dialer = TestNetwork::new(dialer_transport, dialer_id, NetworkConfig::default()); - let mut listener = TestNetwork::new(listener_transport, listener_id, NetworkConfig::default()); + let mut listener = + TestNetwork::new(listener_transport, listener_id, NetworkConfig::default()); listener.listen_on(listen_addr).unwrap(); let (addr_sender, addr_receiver) = oneshot::channel(); @@ -71,33 +76,26 @@ fn transport_upgrade() { let client = async move { let addr = addr_receiver.await.unwrap(); dialer.dial(&addr, TestHandler()).unwrap(); - futures::future::poll_fn(move |cx| { - loop { - match ready!(dialer.poll(cx)) { - NetworkEvent::ConnectionEstablished { .. } => { - return Poll::Ready(()) - } - _ => {} - } + futures::future::poll_fn(move |cx| loop { + match ready!(dialer.poll(cx)) { + NetworkEvent::ConnectionEstablished { .. } => return Poll::Ready(()), + _ => {} } - }).await + }) + .await }; let mut addr_sender = Some(addr_sender); - let server = futures::future::poll_fn(move |cx| { - loop { - match ready!(listener.poll(cx)) { - NetworkEvent::NewListenerAddress { listen_addr, .. } => { - addr_sender.take().unwrap().send(listen_addr).unwrap(); - } - NetworkEvent::IncomingConnection { connection, .. } => { - listener.accept(connection, TestHandler()).unwrap(); - } - NetworkEvent::ConnectionEstablished { .. } => { - return Poll::Ready(()) - } - _ => {} + let server = futures::future::poll_fn(move |cx| loop { + match ready!(listener.poll(cx)) { + NetworkEvent::NewListenerAddress { listen_addr, .. } => { + addr_sender.take().unwrap().send(listen_addr).unwrap(); + } + NetworkEvent::IncomingConnection { connection, .. } => { + listener.accept(connection, TestHandler()).unwrap(); } + NetworkEvent::ConnectionEstablished { .. } => return Poll::Ready(()), + _ => {} } }); @@ -117,17 +115,21 @@ impl ConnectionHandler for TestHandler { type Substream = Substream; type OutboundOpenInfo = (); - fn inject_substream(&mut self, _: Self::Substream, _: SubstreamEndpoint) - {} + fn inject_substream( + &mut self, + _: Self::Substream, + _: SubstreamEndpoint, + ) { + } - fn inject_event(&mut self, _: Self::InEvent) - {} + fn inject_event(&mut self, _: Self::InEvent) {} - fn inject_address_change(&mut self, _: &Multiaddr) - {} + fn inject_address_change(&mut self, _: &Multiaddr) {} - fn poll(&mut self, _: &mut Context<'_>) - -> Poll, Self::Error>> + fn poll( + &mut self, + _: &mut Context<'_>, + ) -> Poll, Self::Error>> { Poll::Pending } diff --git a/misc/peer-id-generator/src/main.rs b/misc/peer-id-generator/src/main.rs index 6ac7af7e358..45239317396 100644 --- a/misc/peer-id-generator/src/main.rs +++ b/misc/peer-id-generator/src/main.rs @@ -26,22 +26,26 @@ fn main() { // bytes 0x1220, meaning that only some characters are valid. const ALLOWED_FIRST_BYTE: &'static [u8] = b"NPQRSTUVWXYZ"; - let prefix = - match env::args().nth(1) { - Some(prefix) => prefix, - None => { - println!( + let prefix = match env::args().nth(1) { + Some(prefix) => prefix, + None => { + println!( "Usage: {} \n\n\ Generates a peer id that starts with the chosen prefix using a secp256k1 public \ key.\n\n\ Prefix must be a sequence of characters in the base58 \ alphabet, and must start with one of the following: {}", - env::current_exe().unwrap().file_name().unwrap().to_str().unwrap(), + env::current_exe() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap(), str::from_utf8(ALLOWED_FIRST_BYTE).unwrap() ); - return; - } - }; + return; + } + }; // The base58 alphabet is not necessarily obvious. const ALPHABET: &'static [u8] = b"123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"; diff --git a/muxers/mplex/benches/split_send_size.rs b/muxers/mplex/benches/split_send_size.rs index c31703cb927..5380f21cc6b 100644 --- a/muxers/mplex/benches/split_send_size.rs +++ b/muxers/mplex/benches/split_send_size.rs @@ -24,9 +24,12 @@ use async_std::task; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use futures::channel::oneshot; -use futures::prelude::*; use futures::future::poll_fn; -use libp2p_core::{PeerId, Transport, StreamMuxer, identity, upgrade, transport, muxing, multiaddr::multiaddr, Multiaddr}; +use futures::prelude::*; +use libp2p_core::{ + identity, multiaddr::multiaddr, muxing, transport, upgrade, Multiaddr, PeerId, StreamMuxer, + Transport, +}; use libp2p_mplex as mplex; use libp2p_plaintext::PlainText2Config; use std::time::Duration; @@ -51,14 +54,13 @@ fn prepare(c: &mut Criterion) { let payload: Vec = vec![1; 1024 * 1024 * 1]; let mut tcp = c.benchmark_group("tcp"); - let tcp_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127,0,0,1)), Tcp(0u16)]; + let tcp_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1)), Tcp(0u16)]; for &size in BENCH_SIZES.iter() { tcp.throughput(Throughput::Bytes(payload.len() as u64)); let trans = tcp_transport(size); - tcp.bench_function( - format!("{}", size), - |b| b.iter(|| run(black_box(&trans), black_box(&payload), black_box(&tcp_addr))) - ); + tcp.bench_function(format!("{}", size), |b| { + b.iter(|| run(black_box(&trans), black_box(&payload), black_box(&tcp_addr))) + }); } tcp.finish(); @@ -67,15 +69,13 @@ fn prepare(c: &mut Criterion) { for &size in BENCH_SIZES.iter() { mem.throughput(Throughput::Bytes(payload.len() as u64)); let trans = mem_transport(size); - mem.bench_function( - format!("{}", size), - |b| b.iter(|| run(black_box(&trans), black_box(&payload), black_box(&mem_addr))) - ); + mem.bench_function(format!("{}", size), |b| { + b.iter(|| run(black_box(&trans), black_box(&payload), black_box(&mem_addr))) + }); } mem.finish(); } - /// Transfers the given payload between two nodes using the given transport. fn run(transport: &BenchTransport, payload: &Vec, listen_addr: &Multiaddr) { let mut listener = transport.clone().listen_on(listen_addr.clone()).unwrap(); @@ -101,18 +101,20 @@ fn run(transport: &BenchTransport, payload: &Vec, listen_addr: &Multiaddr) { let end = off + std::cmp::min(buf.len() - off, 8 * 1024); let n = poll_fn(|cx| { conn.read_substream(cx, &mut s, &mut buf[off..end]) - }).await.unwrap(); + }) + .await + .unwrap(); off += n; if off == buf.len() { - return + return; } } } Ok(_) => panic!("Unexpected muxer event"), - Err(e) => panic!("Unexpected error: {:?}", e) + Err(e) => panic!("Unexpected error: {:?}", e), } } - _ => panic!("Unexpected listener event") + _ => panic!("Unexpected listener event"), } } }); @@ -122,16 +124,20 @@ fn run(transport: &BenchTransport, payload: &Vec, listen_addr: &Multiaddr) { let addr = addr_receiver.await.unwrap(); let (_peer, conn) = transport.clone().dial(addr).unwrap().await.unwrap(); let mut handle = conn.open_outbound(); - let mut stream = poll_fn(|cx| conn.poll_outbound(cx, &mut handle)).await.unwrap(); + let mut stream = poll_fn(|cx| conn.poll_outbound(cx, &mut handle)) + .await + .unwrap(); let mut off = 0; loop { - let n = poll_fn(|cx| { - conn.write_substream(cx, &mut stream, &payload[off..]) - }).await.unwrap(); + let n = poll_fn(|cx| conn.write_substream(cx, &mut stream, &payload[off..])) + .await + .unwrap(); off += n; if off == payload.len() { - poll_fn(|cx| conn.flush_substream(cx, &mut stream)).await.unwrap(); - return + poll_fn(|cx| conn.flush_substream(cx, &mut stream)) + .await + .unwrap(); + return; } } }); @@ -147,7 +153,8 @@ fn tcp_transport(split_send_size: usize) -> BenchTransport { let mut mplex = mplex::MplexConfig::default(); mplex.set_split_send_size(split_send_size); - libp2p_tcp::TcpConfig::new().nodelay(true) + libp2p_tcp::TcpConfig::new() + .nodelay(true) .upgrade(upgrade::Version::V1) .authenticate(PlainText2Config { local_public_key }) .multiplex(mplex) diff --git a/muxers/mplex/src/codec.rs b/muxers/mplex/src/codec.rs index f56bb146ad0..3867cd27d8d 100644 --- a/muxers/mplex/src/codec.rs +++ b/muxers/mplex/src/codec.rs @@ -18,10 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use bytes::{BufMut, Bytes, BytesMut}; use asynchronous_codec::{Decoder, Encoder}; +use bytes::{BufMut, Bytes, BytesMut}; use libp2p_core::Endpoint; -use std::{fmt, hash::{Hash, Hasher}, io, mem}; +use std::{ + fmt, + hash::{Hash, Hasher}, + io, mem, +}; use unsigned_varint::{codec, encode}; // Maximum size for a packet: 1MB as per the spec. @@ -82,18 +86,27 @@ pub struct RemoteStreamId { impl LocalStreamId { pub fn dialer(num: u64) -> Self { - Self { num, role: Endpoint::Dialer } + Self { + num, + role: Endpoint::Dialer, + } } #[cfg(test)] pub fn listener(num: u64) -> Self { - Self { num, role: Endpoint::Listener } + Self { + num, + role: Endpoint::Listener, + } } pub fn next(self) -> Self { Self { - num: self.num.checked_add(1).expect("Mplex substream ID overflowed"), - .. self + num: self + .num + .checked_add(1) + .expect("Mplex substream ID overflowed"), + ..self } } @@ -108,11 +121,17 @@ impl LocalStreamId { impl RemoteStreamId { fn dialer(num: u64) -> Self { - Self { num, role: Endpoint::Dialer } + Self { + num, + role: Endpoint::Dialer, + } } fn listener(num: u64) -> Self { - Self { num, role: Endpoint::Listener } + Self { + num, + role: Endpoint::Listener, + } } /// Converts this `RemoteStreamId` into the corresponding `LocalStreamId` @@ -174,31 +193,28 @@ impl Decoder for Codec { fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { loop { match mem::replace(&mut self.decoder_state, CodecDecodeState::Poisoned) { - CodecDecodeState::Begin => { - match self.varint_decoder.decode(src)? { - Some(header) => { - self.decoder_state = CodecDecodeState::HasHeader(header); - }, - None => { - self.decoder_state = CodecDecodeState::Begin; - return Ok(None); - }, + CodecDecodeState::Begin => match self.varint_decoder.decode(src)? { + Some(header) => { + self.decoder_state = CodecDecodeState::HasHeader(header); + } + None => { + self.decoder_state = CodecDecodeState::Begin; + return Ok(None); } }, - CodecDecodeState::HasHeader(header) => { - match self.varint_decoder.decode(src)? { - Some(len) => { - if len as usize > MAX_FRAME_SIZE { - let msg = format!("Mplex frame length {} exceeds maximum", len); - return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); - } - - self.decoder_state = CodecDecodeState::HasHeaderAndLen(header, len as usize); - }, - None => { - self.decoder_state = CodecDecodeState::HasHeader(header); - return Ok(None); - }, + CodecDecodeState::HasHeader(header) => match self.varint_decoder.decode(src)? { + Some(len) => { + if len as usize > MAX_FRAME_SIZE { + let msg = format!("Mplex frame length {} exceeds maximum", len); + return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); + } + + self.decoder_state = + CodecDecodeState::HasHeaderAndLen(header, len as usize); + } + None => { + self.decoder_state = CodecDecodeState::HasHeader(header); + return Ok(None); } }, CodecDecodeState::HasHeaderAndLen(header, len) => { @@ -212,25 +228,44 @@ impl Decoder for Codec { let buf = src.split_to(len); let num = (header >> 3) as u64; let out = match header & 7 { - 0 => Frame::Open { stream_id: RemoteStreamId::dialer(num) }, - 1 => Frame::Data { stream_id: RemoteStreamId::listener(num), data: buf.freeze() }, - 2 => Frame::Data { stream_id: RemoteStreamId::dialer(num), data: buf.freeze() }, - 3 => Frame::Close { stream_id: RemoteStreamId::listener(num) }, - 4 => Frame::Close { stream_id: RemoteStreamId::dialer(num) }, - 5 => Frame::Reset { stream_id: RemoteStreamId::listener(num) }, - 6 => Frame::Reset { stream_id: RemoteStreamId::dialer(num) }, + 0 => Frame::Open { + stream_id: RemoteStreamId::dialer(num), + }, + 1 => Frame::Data { + stream_id: RemoteStreamId::listener(num), + data: buf.freeze(), + }, + 2 => Frame::Data { + stream_id: RemoteStreamId::dialer(num), + data: buf.freeze(), + }, + 3 => Frame::Close { + stream_id: RemoteStreamId::listener(num), + }, + 4 => Frame::Close { + stream_id: RemoteStreamId::dialer(num), + }, + 5 => Frame::Reset { + stream_id: RemoteStreamId::listener(num), + }, + 6 => Frame::Reset { + stream_id: RemoteStreamId::dialer(num), + }, _ => { let msg = format!("Invalid mplex header value 0x{:x}", header); return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); - }, + } }; self.decoder_state = CodecDecodeState::Begin; return Ok(Some(out)); - }, + } CodecDecodeState::Poisoned => { - return Err(io::Error::new(io::ErrorKind::InvalidData, "Mplex codec poisoned")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Mplex codec poisoned", + )); } } } @@ -243,27 +278,51 @@ impl Encoder for Codec { fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { let (header, data) = match item { - Frame::Open { stream_id } => { - (stream_id.num << 3, Bytes::new()) - }, - Frame::Data { stream_id: LocalStreamId { num, role: Endpoint::Listener }, data } => { - (num << 3 | 1, data) - }, - Frame::Data { stream_id: LocalStreamId { num, role: Endpoint::Dialer }, data } => { - (num << 3 | 2, data) - }, - Frame::Close { stream_id: LocalStreamId { num, role: Endpoint::Listener } } => { - (num << 3 | 3, Bytes::new()) - }, - Frame::Close { stream_id: LocalStreamId { num, role: Endpoint::Dialer } } => { - (num << 3 | 4, Bytes::new()) - }, - Frame::Reset { stream_id: LocalStreamId { num, role: Endpoint::Listener } } => { - (num << 3 | 5, Bytes::new()) - }, - Frame::Reset { stream_id: LocalStreamId { num, role: Endpoint::Dialer } } => { - (num << 3 | 6, Bytes::new()) - }, + Frame::Open { stream_id } => (stream_id.num << 3, Bytes::new()), + Frame::Data { + stream_id: + LocalStreamId { + num, + role: Endpoint::Listener, + }, + data, + } => (num << 3 | 1, data), + Frame::Data { + stream_id: + LocalStreamId { + num, + role: Endpoint::Dialer, + }, + data, + } => (num << 3 | 2, data), + Frame::Close { + stream_id: + LocalStreamId { + num, + role: Endpoint::Listener, + }, + } => (num << 3 | 3, Bytes::new()), + Frame::Close { + stream_id: + LocalStreamId { + num, + role: Endpoint::Dialer, + }, + } => (num << 3 | 4, Bytes::new()), + Frame::Reset { + stream_id: + LocalStreamId { + num, + role: Endpoint::Listener, + }, + } => (num << 3 | 5, Bytes::new()), + Frame::Reset { + stream_id: + LocalStreamId { + num, + role: Endpoint::Dialer, + }, + } => (num << 3 | 6, Bytes::new()), }; let mut header_buf = encode::u64_buffer(); @@ -274,7 +333,10 @@ impl Encoder for Codec { let data_len_bytes = encode::usize(data_len, &mut data_buf); if data_len > MAX_FRAME_SIZE { - return Err(io::Error::new(io::ErrorKind::InvalidData, "data size exceed maximum")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "data size exceed maximum", + )); } dst.reserve(header_bytes.len() + data_len_bytes.len() + data_len); @@ -294,15 +356,21 @@ mod tests { let mut enc = Codec::new(); let role = Endpoint::Dialer; let data = Bytes::from(&[123u8; MAX_FRAME_SIZE + 1][..]); - let bad_msg = Frame::Data { stream_id: LocalStreamId { num: 123, role }, data }; + let bad_msg = Frame::Data { + stream_id: LocalStreamId { num: 123, role }, + data, + }; let mut out = BytesMut::new(); match enc.encode(bad_msg, &mut out) { Err(e) => assert_eq!(e.to_string(), "data size exceed maximum"), - _ => panic!("Can't send a message bigger than MAX_FRAME_SIZE") + _ => panic!("Can't send a message bigger than MAX_FRAME_SIZE"), } let data = Bytes::from(&[123u8; MAX_FRAME_SIZE][..]); - let ok_msg = Frame::Data { stream_id: LocalStreamId { num: 123, role }, data }; + let ok_msg = Frame::Data { + stream_id: LocalStreamId { num: 123, role }, + data, + }; assert!(enc.encode(ok_msg, &mut out).is_ok()); } @@ -311,19 +379,24 @@ mod tests { // Create new codec object for encoding and decoding our frame. let mut codec = Codec::new(); // Create a u64 stream ID. - let id: u64 = u32::MAX as u64 + 1 ; - let stream_id = LocalStreamId { num: id, role: Endpoint::Dialer }; + let id: u64 = u32::MAX as u64 + 1; + let stream_id = LocalStreamId { + num: id, + role: Endpoint::Dialer, + }; // Open a new frame with that stream ID. let original_frame = Frame::Open { stream_id }; // Encode that frame. let mut enc_frame = BytesMut::new(); - codec.encode(original_frame, &mut enc_frame) + codec + .encode(original_frame, &mut enc_frame) .expect("Encoding to succeed."); // Decode encoded frame and extract stream ID. - let dec_string_id = codec.decode(&mut enc_frame) + let dec_string_id = codec + .decode(&mut enc_frame) .expect("Decoding to succeed.") .map(|f| f.remote_id()) .unwrap(); diff --git a/muxers/mplex/src/io.rs b/muxers/mplex/src/io.rs index e4e49935b8d..80da197a965 100644 --- a/muxers/mplex/src/io.rs +++ b/muxers/mplex/src/io.rs @@ -18,20 +18,24 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use bytes::Bytes; -use crate::{MplexConfig, MaxBufferBehaviour}; use crate::codec::{Codec, Frame, LocalStreamId, RemoteStreamId}; -use log::{debug, trace}; -use futures::{prelude::*, ready, stream::Fuse}; -use futures::task::{AtomicWaker, ArcWake, waker_ref, WakerRef}; +use crate::{MaxBufferBehaviour, MplexConfig}; use asynchronous_codec::Framed; +use bytes::Bytes; +use futures::task::{waker_ref, ArcWake, AtomicWaker, WakerRef}; +use futures::{prelude::*, ready, stream::Fuse}; +use log::{debug, trace}; use nohash_hasher::{IntMap, IntSet}; use parking_lot::Mutex; use smallvec::SmallVec; use std::collections::VecDeque; -use std::{cmp, fmt, io, mem, sync::Arc, task::{Context, Poll, Waker}}; +use std::{ + cmp, fmt, io, mem, + sync::Arc, + task::{Context, Poll, Waker}, +}; -pub use std::io::{Result, Error, ErrorKind}; +pub use std::io::{Error, ErrorKind, Result}; /// A connection identifier. /// @@ -109,7 +113,7 @@ enum Status { impl Multiplexed where - C: AsyncRead + AsyncWrite + Unpin + C: AsyncRead + AsyncWrite + Unpin, { /// Creates a new multiplexed I/O stream. pub fn new(io: C, config: MplexConfig) -> Self { @@ -134,8 +138,8 @@ where pending: Mutex::new(Default::default()), }), notifier_open: NotifierOpen { - pending: Default::default() - } + pending: Default::default(), + }, } } @@ -223,14 +227,14 @@ where // from the respective substreams. if num_buffered == self.config.max_buffer_len { cx.waker().clone().wake(); - return Poll::Pending + return Poll::Pending; } // Wait for the next inbound `Open` frame. match ready!(self.poll_read_frame(cx, None))? { Frame::Open { stream_id } => { if let Some(id) = self.on_open(stream_id)? { - return Poll::Ready(Ok(id)) + return Poll::Ready(Ok(id)); } } Frame::Data { stream_id, data } => { @@ -240,9 +244,7 @@ where Frame::Close { stream_id } => { self.on_close(stream_id.into_local()); } - Frame::Reset { stream_id } => { - self.on_reset(stream_id.into_local()) - } + Frame::Reset { stream_id } => self.on_reset(stream_id.into_local()), } } } @@ -253,10 +255,12 @@ where // Check the stream limits. if self.substreams.len() >= self.config.max_substreams { - debug!("{}: Maximum number of substreams reached ({})", - self.id, self.config.max_substreams); + debug!( + "{}: Maximum number of substreams reached ({})", + self.id, self.config.max_substreams + ); self.notifier_open.register(cx.waker()); - return Poll::Pending + return Poll::Pending; } // Send the `Open` frame. @@ -267,11 +271,18 @@ where let frame = Frame::Open { stream_id }; match self.io.start_send_unpin(frame) { Ok(()) => { - self.substreams.insert(stream_id, SubstreamState::Open { - buf: Default::default() - }); - debug!("{}: New outbound substream: {} (total {})", - self.id, stream_id, self.substreams.len()); + self.substreams.insert( + stream_id, + SubstreamState::Open { + buf: Default::default(), + }, + ); + debug!( + "{}: New outbound substream: {} (total {})", + self.id, + stream_id, + self.substreams.len() + ); // The flush is delayed and the `Open` frame may be sent // together with other frames in the same transport packet. self.pending_flush_open.insert(stream_id); @@ -279,8 +290,8 @@ where } Err(e) => Poll::Ready(self.on_error(e)), } - }, - Err(e) => Poll::Ready(self.on_error(e)) + } + Err(e) => Poll::Ready(self.on_error(e)), } } @@ -310,7 +321,7 @@ where // Check if the underlying stream is ok. match self.status { Status::Closed | Status::Err(_) => return, - Status::Open => {}, + Status::Open => {} } // If there is still a task waker interested in reading from that @@ -321,7 +332,7 @@ where // Remove the substream, scheduling pending frames as necessary. match self.substreams.remove(&id) { - None => {}, + None => {} Some(state) => { // If we fell below the substream limit, notify tasks that had // interest in opening an outbound substream earlier. @@ -336,17 +347,19 @@ where SubstreamState::Reset { .. } => {} SubstreamState::RecvClosed { .. } => { if self.check_max_pending_frames().is_err() { - return + return; } trace!("{}: Pending close for stream {}", self.id, id); - self.pending_frames.push_front(Frame::Close { stream_id: id }); + self.pending_frames + .push_front(Frame::Close { stream_id: id }); } SubstreamState::Open { .. } => { if self.check_max_pending_frames().is_err() { - return + return; } trace!("{}: Pending reset for stream {}", self.id, id); - self.pending_frames.push_front(Frame::Reset { stream_id: id }); + self.pending_frames + .push_front(Frame::Reset { stream_id: id }); } } } @@ -354,17 +367,22 @@ where } /// Writes data to a substream. - pub fn poll_write_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId, buf: &[u8]) - -> Poll> - { + pub fn poll_write_stream( + &mut self, + cx: &mut Context<'_>, + id: LocalStreamId, + buf: &[u8], + ) -> Poll> { self.guard_open()?; // Check if the stream is open for writing. match self.substreams.get(&id) { - None | Some(SubstreamState::Reset { .. }) => - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())), - Some(SubstreamState::SendClosed { .. }) | Some(SubstreamState::Closed { .. }) => - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + None | Some(SubstreamState::Reset { .. }) => { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } + Some(SubstreamState::SendClosed { .. }) | Some(SubstreamState::Closed { .. }) => { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) + } Some(SubstreamState::Open { .. }) | Some(SubstreamState::RecvClosed { .. }) => { // Substream is writeable. Continue. } @@ -375,8 +393,11 @@ where // Send the data frame. ready!(self.poll_send_frame(cx, || { - let data = Bytes::copy_from_slice(&buf[.. frame_len]); - Frame::Data { stream_id: id, data } + let data = Bytes::copy_from_slice(&buf[..frame_len]); + Frame::Data { + stream_id: id, + data, + } }))?; Poll::Ready(Ok(frame_len)) @@ -396,9 +417,11 @@ where /// and under consideration of the number of already used substreams, /// thereby waking the task that last called `poll_next_stream`, if any. /// Inbound substreams received in excess of that limit are immediately reset. - pub fn poll_read_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId) - -> Poll>> - { + pub fn poll_read_stream( + &mut self, + cx: &mut Context<'_>, + id: LocalStreamId, + ) -> Poll>> { self.guard_open()?; // Try to read from the buffer first. @@ -411,7 +434,7 @@ where ArcWake::wake_by_ref(&self.notifier_read); } let data = buf.remove(0); - return Poll::Ready(Ok(Some(data))) + return Poll::Ready(Ok(Some(data))); } // If the stream buffer "spilled" onto the heap, free that memory. buf.shrink_to_fit(); @@ -426,7 +449,7 @@ where // a chance to read from the other substream(s). if num_buffered == self.config.max_buffer_len { cx.waker().clone().wake(); - return Poll::Pending + return Poll::Pending; } // Check if the targeted substream (if any) reached EOF. @@ -436,14 +459,14 @@ where // remote, as the `StreamMuxer::read_substream` contract only // permits errors on "terminal" conditions, e.g. if the connection // has been closed or on protocol misbehaviour. - return Poll::Ready(Ok(None)) + return Poll::Ready(Ok(None)); } // Read the next frame. match ready!(self.poll_read_frame(cx, Some(id)))? { Frame::Data { data, stream_id } if stream_id.into_local() == id => { return Poll::Ready(Ok(Some(data))) - }, + } Frame::Data { stream_id, data } => { // The data frame is for a different stream than the one // currently being polled, so it needs to be buffered and @@ -454,7 +477,12 @@ where frame @ Frame::Open { .. } => { if let Some(id) = self.on_open(frame.remote_id())? { self.open_buffer.push_front(id); - trace!("{}: Buffered new inbound stream {} (total: {})", self.id, id, self.open_buffer.len()); + trace!( + "{}: Buffered new inbound stream {} (total: {})", + self.id, + id, + self.open_buffer.len() + ); self.notifier_read.wake_next_stream(); } } @@ -462,14 +490,14 @@ where let stream_id = stream_id.into_local(); self.on_close(stream_id); if id == stream_id { - return Poll::Ready(Ok(None)) + return Poll::Ready(Ok(None)); } } Frame::Reset { stream_id } => { let stream_id = stream_id.into_local(); self.on_reset(stream_id); if id == stream_id { - return Poll::Ready(Ok(None)) + return Poll::Ready(Ok(None)); } } } @@ -481,9 +509,11 @@ where /// > **Note**: This is equivalent to `poll_flush()`, i.e. to flushing /// > all substreams, except that this operation returns an error if /// > the underlying I/O stream is already closed. - pub fn poll_flush_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId) - -> Poll> - { + pub fn poll_flush_stream( + &mut self, + cx: &mut Context<'_>, + id: LocalStreamId, + ) -> Poll> { self.guard_open()?; ready!(self.poll_flush(cx))?; @@ -495,15 +525,18 @@ where /// Closes a stream for writing. /// /// > **Note**: As opposed to `poll_close()`, a flush it not implied. - pub fn poll_close_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId) - -> Poll> - { + pub fn poll_close_stream( + &mut self, + cx: &mut Context<'_>, + id: LocalStreamId, + ) -> Poll> { self.guard_open()?; match self.substreams.remove(&id) { None => Poll::Ready(Ok(())), Some(SubstreamState::SendClosed { buf }) => { - self.substreams.insert(id, SubstreamState::SendClosed { buf }); + self.substreams + .insert(id, SubstreamState::SendClosed { buf }); Poll::Ready(Ok(())) } Some(SubstreamState::Closed { buf }) => { @@ -515,18 +548,26 @@ where Poll::Ready(Ok(())) } Some(SubstreamState::Open { buf }) => { - if self.poll_send_frame(cx, || Frame::Close { stream_id: id })?.is_pending() { + if self + .poll_send_frame(cx, || Frame::Close { stream_id: id })? + .is_pending() + { self.substreams.insert(id, SubstreamState::Open { buf }); Poll::Pending } else { debug!("{}: Closed substream {} (half-close)", self.id, id); - self.substreams.insert(id, SubstreamState::SendClosed { buf }); + self.substreams + .insert(id, SubstreamState::SendClosed { buf }); Poll::Ready(Ok(())) } } Some(SubstreamState::RecvClosed { buf }) => { - if self.poll_send_frame(cx, || Frame::Close { stream_id: id })?.is_pending() { - self.substreams.insert(id, SubstreamState::RecvClosed { buf }); + if self + .poll_send_frame(cx, || Frame::Close { stream_id: id })? + .is_pending() + { + self.substreams + .insert(id, SubstreamState::RecvClosed { buf }); Poll::Pending } else { debug!("{}: Closed substream {}", self.id, id); @@ -541,10 +582,9 @@ where /// /// The frame is only constructed if the underlying sink is ready to /// send another frame. - fn poll_send_frame(&mut self, cx: &mut Context<'_>, frame: F) - -> Poll> + fn poll_send_frame(&mut self, cx: &mut Context<'_>, frame: F) -> Poll> where - F: FnOnce() -> Frame + F: FnOnce() -> Frame, { let waker = NotifierWrite::register(&self.notifier_write, cx.waker()); match ready!(self.io.poll_ready_unpin(&mut Context::from_waker(&waker))) { @@ -553,10 +593,10 @@ where trace!("{}: Sending {:?}", self.id, frame); match self.io.start_send_unpin(frame) { Ok(()) => Poll::Ready(Ok(())), - Err(e) => Poll::Ready(self.on_error(e)) + Err(e) => Poll::Ready(self.on_error(e)), } - }, - Err(e) => Poll::Ready(self.on_error(e)) + } + Err(e) => Poll::Ready(self.on_error(e)), } } @@ -566,12 +606,14 @@ where /// the current task is interested and wants to be woken up for, /// in case new frames can be read. `None` means interest in /// frames for any substream. - fn poll_read_frame(&mut self, cx: &mut Context<'_>, stream_id: Option) - -> Poll>> - { + fn poll_read_frame( + &mut self, + cx: &mut Context<'_>, + stream_id: Option, + ) -> Poll>> { // Try to send pending frames, if there are any, without blocking, if let Poll::Ready(Err(e)) = self.send_pending_frames(cx) { - return Poll::Ready(Err(e)) + return Poll::Ready(Err(e)); } // Perform any pending flush before reading. @@ -593,13 +635,19 @@ where if !self.notifier_read.wake_read_stream(*blocked_id) { // No task dedicated to the blocked stream woken, so schedule // this task again to have a chance at progress. - trace!("{}: No task to read from blocked stream. Waking current task.", self.id); + trace!( + "{}: No task to read from blocked stream. Waking current task.", + self.id + ); cx.waker().clone().wake(); } else if let Some(id) = stream_id { // We woke some other task, but are still interested in // reading `Data` frames from the current stream when unblocked. - debug_assert!(blocked_id != &id, "Unexpected attempt at reading a new \ - frame from a substream with a full buffer."); + debug_assert!( + blocked_id != &id, + "Unexpected attempt at reading a new \ + frame from a substream with a full buffer." + ); let _ = NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id); } else { // We woke some other task but are still interested in @@ -607,13 +655,13 @@ where let _ = NotifierRead::register_next_stream(&self.notifier_read, cx.waker()); } - return Poll::Pending + return Poll::Pending; } // Try to read another frame from the underlying I/O stream. let waker = match stream_id { Some(id) => NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id), - None => NotifierRead::register_next_stream(&self.notifier_read, cx.waker()) + None => NotifierRead::register_next_stream(&self.notifier_read, cx.waker()), }; match ready!(self.io.poll_next_unpin(&mut Context::from_waker(&waker))) { Some(Ok(frame)) => { @@ -621,7 +669,7 @@ where Poll::Ready(Ok(frame)) } Some(Err(e)) => Poll::Ready(self.on_error(e)), - None => Poll::Ready(self.on_error(io::ErrorKind::UnexpectedEof.into())) + None => Poll::Ready(self.on_error(io::ErrorKind::UnexpectedEof.into())), } } @@ -630,27 +678,41 @@ where let id = id.into_local(); if self.substreams.contains_key(&id) { - debug!("{}: Received unexpected `Open` frame for open substream {}", self.id, id); - return self.on_error(io::Error::new(io::ErrorKind::Other, - "Protocol error: Received `Open` frame for open substream.")) + debug!( + "{}: Received unexpected `Open` frame for open substream {}", + self.id, id + ); + return self.on_error(io::Error::new( + io::ErrorKind::Other, + "Protocol error: Received `Open` frame for open substream.", + )); } if self.substreams.len() >= self.config.max_substreams { - debug!("{}: Maximum number of substreams exceeded: {}", - self.id, self.config.max_substreams); + debug!( + "{}: Maximum number of substreams exceeded: {}", + self.id, self.config.max_substreams + ); self.check_max_pending_frames()?; debug!("{}: Pending reset for new stream {}", self.id, id); - self.pending_frames.push_front(Frame::Reset { - stream_id: id - }); - return Ok(None) + self.pending_frames + .push_front(Frame::Reset { stream_id: id }); + return Ok(None); } - self.substreams.insert(id, SubstreamState::Open { - buf: Default::default() - }); + self.substreams.insert( + id, + SubstreamState::Open { + buf: Default::default(), + }, + ); - debug!("{}: New inbound substream: {} (total {})", self.id, id, self.substreams.len()); + debug!( + "{}: New inbound substream: {} (total {})", + self.id, + id, + self.substreams.len() + ); Ok(Some(id)) } @@ -660,15 +722,22 @@ where if let Some(state) = self.substreams.remove(&id) { match state { SubstreamState::Closed { .. } => { - trace!("{}: Ignoring reset for mutually closed substream {}.", self.id, id); + trace!( + "{}: Ignoring reset for mutually closed substream {}.", + self.id, + id + ); } SubstreamState::Reset { .. } => { - trace!("{}: Ignoring redundant reset for already reset substream {}", - self.id, id); + trace!( + "{}: Ignoring redundant reset for already reset substream {}", + self.id, + id + ); } - SubstreamState::RecvClosed { buf } | - SubstreamState::SendClosed { buf } | - SubstreamState::Open { buf } => { + SubstreamState::RecvClosed { buf } + | SubstreamState::SendClosed { buf } + | SubstreamState::Open { buf } => { debug!("{}: Substream {} reset by remote.", self.id, id); self.substreams.insert(id, SubstreamState::Reset { buf }); // Notify tasks interested in reading from that stream, @@ -677,8 +746,11 @@ where } } } else { - trace!("{}: Ignoring `Reset` for unknown substream {}. Possibly dropped earlier.", - self.id, id); + trace!( + "{}: Ignoring `Reset` for unknown substream {}. Possibly dropped earlier.", + self.id, + id + ); } } @@ -687,33 +759,45 @@ where if let Some(state) = self.substreams.remove(&id) { match state { SubstreamState::RecvClosed { .. } | SubstreamState::Closed { .. } => { - debug!("{}: Ignoring `Close` frame for closed substream {}", - self.id, id); + debug!( + "{}: Ignoring `Close` frame for closed substream {}", + self.id, id + ); self.substreams.insert(id, state); - }, + } SubstreamState::Reset { buf } => { - debug!("{}: Ignoring `Close` frame for already reset substream {}", - self.id, id); + debug!( + "{}: Ignoring `Close` frame for already reset substream {}", + self.id, id + ); self.substreams.insert(id, SubstreamState::Reset { buf }); } SubstreamState::SendClosed { buf } => { - debug!("{}: Substream {} closed by remote (SendClosed -> Closed).", - self.id, id); + debug!( + "{}: Substream {} closed by remote (SendClosed -> Closed).", + self.id, id + ); self.substreams.insert(id, SubstreamState::Closed { buf }); // Notify tasks interested in reading, so they may read the EOF. self.notifier_read.wake_read_stream(id); - }, + } SubstreamState::Open { buf } => { - debug!("{}: Substream {} closed by remote (Open -> RecvClosed)", - self.id, id); - self.substreams.insert(id, SubstreamState::RecvClosed { buf }); + debug!( + "{}: Substream {} closed by remote (Open -> RecvClosed)", + self.id, id + ); + self.substreams + .insert(id, SubstreamState::RecvClosed { buf }); // Notify tasks interested in reading, so they may read the EOF. self.notifier_read.wake_read_stream(id); - }, + } } } else { - trace!("{}: Ignoring `Close` for unknown substream {}. Possibly dropped earlier.", - self.id, id); + trace!( + "{}: Ignoring `Close` for unknown substream {}. Possibly dropped earlier.", + self.id, + id + ); } } @@ -735,11 +819,9 @@ where /// Sends pending frames, without flushing. fn send_pending_frames(&mut self, cx: &mut Context<'_>) -> Poll> { while let Some(frame) = self.pending_frames.pop_back() { - if self.poll_send_frame(cx, || { - frame.clone() - })?.is_pending() { + if self.poll_send_frame(cx, || frame.clone())?.is_pending() { self.pending_frames.push_back(frame); - return Poll::Pending + return Poll::Pending; } } @@ -750,7 +832,7 @@ where fn on_error(&mut self, e: io::Error) -> io::Result { debug!("{}: Multiplexed connection failed: {:?}", self.id, e); self.status = Status::Err(io::Error::new(e.kind(), e.to_string())); - self.pending_frames = Default::default(); + self.pending_frames = Default::default(); self.substreams = Default::default(); self.open_buffer = Default::default(); Err(e) @@ -762,7 +844,7 @@ where match &self.status { Status::Closed => Err(io::Error::new(io::ErrorKind::Other, "Connection is closed")), Status::Err(e) => Err(io::Error::new(e.kind(), e.to_string())), - Status::Open => Ok(()) + Status::Open => Ok(()), } } @@ -770,8 +852,10 @@ where /// has not been reached. fn check_max_pending_frames(&mut self) -> io::Result<()> { if self.pending_frames.len() >= self.config.max_substreams + EXTRA_PENDING_FRAMES { - return self.on_error(io::Error::new(io::ErrorKind::Other, - "Too many pending frames.")); + return self.on_error(io::Error::new( + io::ErrorKind::Other, + "Too many pending frames.", + )); } Ok(()) } @@ -789,19 +873,35 @@ where let state = if let Some(state) = self.substreams.get_mut(&id) { state } else { - trace!("{}: Dropping data {:?} for unknown substream {}", self.id, data, id); - return Ok(()) + trace!( + "{}: Dropping data {:?} for unknown substream {}", + self.id, + data, + id + ); + return Ok(()); }; let buf = if let Some(buf) = state.recv_buf_open() { buf } else { - trace!("{}: Dropping data {:?} for closed or reset substream {}", self.id, data, id); - return Ok(()) + trace!( + "{}: Dropping data {:?} for closed or reset substream {}", + self.id, + data, + id + ); + return Ok(()); }; debug_assert!(buf.len() <= self.config.max_buffer_len); - trace!("{}: Buffering {:?} for stream {} (total: {})", self.id, data, id, buf.len() + 1); + trace!( + "{}: Buffering {:?} for stream {} (total: {})", + self.id, + data, + id, + buf.len() + 1 + ); buf.push(data); self.notifier_read.wake_read_stream(id); if buf.len() > self.config.max_buffer_len { @@ -812,9 +912,8 @@ where self.check_max_pending_frames()?; self.substreams.insert(id, SubstreamState::Reset { buf }); debug!("{}: Pending reset for stream {}", self.id, id); - self.pending_frames.push_front(Frame::Reset { - stream_id: id - }); + self.pending_frames + .push_front(Frame::Reset { stream_id: id }); } MaxBufferBehaviour::Block => { self.blocking_stream = Some(id); @@ -845,7 +944,7 @@ enum SubstreamState { Closed { buf: RecvBuf }, /// The stream has been reset by the local or remote peer but has /// not yet been dropped and may still have buffered frames to read. - Reset { buf: RecvBuf } + Reset { buf: RecvBuf }, } impl SubstreamState { @@ -889,9 +988,11 @@ impl NotifierRead { /// The returned waker should be passed to an I/O read operation /// that schedules a wakeup, if the operation is pending. #[must_use] - fn register_read_stream<'a>(self: &'a Arc, waker: &Waker, id: LocalStreamId) - -> WakerRef<'a> - { + fn register_read_stream<'a>( + self: &'a Arc, + waker: &Waker, + id: LocalStreamId, + ) -> WakerRef<'a> { let mut pending = self.read_stream.lock(); pending.insert(id, waker.clone()); waker_ref(self) @@ -914,7 +1015,7 @@ impl NotifierRead { if let Some(waker) = pending.remove(&id) { waker.wake(); - return true + return true; } false @@ -999,21 +1100,23 @@ const EXTRA_PENDING_FRAMES: usize = 1000; #[cfg(test)] mod tests { + use super::*; use async_std::task; + use asynchronous_codec::{Decoder, Encoder}; use bytes::BytesMut; use futures::prelude::*; - use asynchronous_codec::{Decoder, Encoder}; use quickcheck::*; use rand::prelude::*; use std::collections::HashSet; use std::num::NonZeroU8; use std::ops::DerefMut; use std::pin::Pin; - use super::*; impl Arbitrary for MaxBufferBehaviour { fn arbitrary(g: &mut G) -> MaxBufferBehaviour { - *[MaxBufferBehaviour::Block, MaxBufferBehaviour::ResetStream].choose(g).unwrap() + *[MaxBufferBehaviour::Block, MaxBufferBehaviour::ResetStream] + .choose(g) + .unwrap() } } @@ -1042,10 +1145,10 @@ mod tests { fn poll_read( mut self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &mut [u8] + buf: &mut [u8], ) -> Poll> { if self.eof { - return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())) + return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())); } let n = std::cmp::min(buf.len(), self.r_buf.len()); let data = self.r_buf.split_to(n); @@ -1062,23 +1165,17 @@ mod tests { fn poll_write( mut self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &[u8] + buf: &[u8], ) -> Poll> { self.w_buf.extend_from_slice(buf); Poll::Ready(Ok(buf.len())) } - fn poll_flush( - self: Pin<&mut Self>, - _: &mut Context<'_> - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn poll_close( - self: Pin<&mut Self>, - _: &mut Context<'_> - ) -> Poll> { + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } @@ -1092,25 +1189,37 @@ mod tests { let mut codec = Codec::new(); // Open the maximum number of inbound streams. - for i in 0 .. cfg.max_substreams { + for i in 0..cfg.max_substreams { let stream_id = LocalStreamId::dialer(i as u64); - codec.encode(Frame::Open { stream_id }, &mut r_buf).unwrap(); + codec.encode(Frame::Open { stream_id }, &mut r_buf).unwrap(); } // Send more data on stream 0 than the buffer permits. let stream_id = LocalStreamId::dialer(0); let data = Bytes::from("Hello world"); - for _ in 0 .. cfg.max_buffer_len + overflow.get() as usize { - codec.encode(Frame::Data { stream_id, data: data.clone() }, &mut r_buf).unwrap(); + for _ in 0..cfg.max_buffer_len + overflow.get() as usize { + codec + .encode( + Frame::Data { + stream_id, + data: data.clone(), + }, + &mut r_buf, + ) + .unwrap(); } // Setup the multiplexed connection. - let conn = Connection { r_buf, w_buf: BytesMut::new(), eof: false }; + let conn = Connection { + r_buf, + w_buf: BytesMut::new(), + eof: false, + }; let mut m = Multiplexed::new(conn, cfg.clone()); task::block_on(future::poll_fn(move |cx| { // Receive all inbound streams. - for i in 0 .. cfg.max_substreams { + for i in 0..cfg.max_substreams { match m.poll_next_stream(cx) { Poll::Pending => panic!("Expected new inbound stream."), Poll::Ready(Err(e)) => panic!("{:?}", e), @@ -1161,7 +1270,7 @@ mod tests { } MaxBufferBehaviour::Block => { assert!(m.poll_next_stream(cx).is_pending()); - for i in 1 .. cfg.max_substreams { + for i in 1..cfg.max_substreams { let id = LocalStreamId::listener(i as u64); assert!(m.poll_read_stream(cx, id).is_pending()); } @@ -1169,12 +1278,12 @@ mod tests { } // Drain the buffer by reading from the stream. - for _ in 0 .. cfg.max_buffer_len + 1 { + for _ in 0..cfg.max_buffer_len + 1 { match m.poll_read_stream(cx, id) { Poll::Ready(Ok(Some(bytes))) => { assert_eq!(bytes, data); } - x => panic!("Unexpected: {:?}", x) + x => panic!("Unexpected: {:?}", x), } } @@ -1185,8 +1294,8 @@ mod tests { MaxBufferBehaviour::ResetStream => { // Expect to read EOF match m.poll_read_stream(cx, id) { - Poll::Ready(Ok(None)) => {}, - poll => panic!("Unexpected: {:?}", poll) + Poll::Ready(Ok(None)) => {} + poll => panic!("Unexpected: {:?}", poll), } } MaxBufferBehaviour::Block => { @@ -1194,7 +1303,7 @@ mod tests { match m.poll_read_stream(cx, id) { Poll::Ready(Ok(Some(bytes))) => assert_eq!(bytes, data), Poll::Pending => assert_eq!(overflow.get(), 1), - poll => panic!("Unexpected: {:?}", poll) + poll => panic!("Unexpected: {:?}", poll), } } } @@ -1203,7 +1312,7 @@ mod tests { })); } - quickcheck(prop as fn(_,_)) + quickcheck(prop as fn(_, _)) } #[test] @@ -1217,7 +1326,7 @@ mod tests { let conn = Connection { r_buf: BytesMut::new(), w_buf: BytesMut::new(), - eof: false + eof: false, }; let mut m = Multiplexed::new(conn, cfg.clone()); @@ -1225,7 +1334,7 @@ mod tests { let mut opened = HashSet::new(); task::block_on(future::poll_fn(move |cx| { // Open a number of streams. - for _ in 0 .. num_streams { + for _ in 0..num_streams { let id = ready!(m.poll_open_stream(cx)).unwrap(); assert!(opened.insert(id)); assert!(m.poll_read_stream(cx, id).is_pending()); @@ -1238,7 +1347,7 @@ mod tests { // should be closed due to the failed connection. assert!(opened.iter().all(|id| match m.poll_read_stream(cx, *id) { Poll::Ready(Err(e)) => e.kind() == io::ErrorKind::UnexpectedEof, - _ => false + _ => false, })); assert!(m.substreams.is_empty()); @@ -1247,6 +1356,6 @@ mod tests { })) } - quickcheck(prop as fn(_,_)) + quickcheck(prop as fn(_, _)) } } diff --git a/muxers/mplex/src/lib.rs b/muxers/mplex/src/lib.rs index 653c3310ab4..05e7571cf87 100644 --- a/muxers/mplex/src/lib.rs +++ b/muxers/mplex/src/lib.rs @@ -22,18 +22,18 @@ mod codec; mod config; mod io; -pub use config::{MplexConfig, MaxBufferBehaviour}; +pub use config::{MaxBufferBehaviour, MplexConfig}; -use codec::LocalStreamId; -use std::{cmp, iter, task::Context, task::Poll}; use bytes::Bytes; +use codec::LocalStreamId; +use futures::{future, prelude::*, ready}; use libp2p_core::{ - StreamMuxer, muxing::StreamMuxerEvent, upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}, + StreamMuxer, }; use parking_lot::Mutex; -use futures::{prelude::*, future, ready}; +use std::{cmp, iter, task::Context, task::Poll}; impl UpgradeInfo for MplexConfig { type Info = &'static [u8]; @@ -69,7 +69,7 @@ where fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future { future::ready(Ok(Multiplex { - io: Mutex::new(io::Multiplexed::new(socket, self)) + io: Mutex::new(io::Multiplexed::new(socket, self)), })) } } @@ -79,20 +79,21 @@ where /// This implementation isn't capable of detecting when the underlying socket changes its address, /// and no [`StreamMuxerEvent::AddressChange`] event is ever emitted. pub struct Multiplex { - io: Mutex> + io: Mutex>, } impl StreamMuxer for Multiplex where - C: AsyncRead + AsyncWrite + Unpin + C: AsyncRead + AsyncWrite + Unpin, { type Substream = Substream; type OutboundSubstream = OutboundSubstream; type Error = io::Error; - fn poll_event(&self, cx: &mut Context<'_>) - -> Poll>> - { + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll>> { let stream_id = ready!(self.io.lock().poll_next_stream(cx))?; let stream = Substream::new(stream_id); Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(stream))) @@ -102,9 +103,11 @@ where OutboundSubstream {} } - fn poll_outbound(&self, cx: &mut Context<'_>, _: &mut Self::OutboundSubstream) - -> Poll> - { + fn poll_outbound( + &self, + cx: &mut Context<'_>, + _: &mut Self::OutboundSubstream, + ) -> Poll> { let stream_id = ready!(self.io.lock().poll_open_stream(cx))?; Poll::Ready(Ok(Substream::new(stream_id))) } @@ -113,9 +116,12 @@ where // Nothing to do, since `open_outbound` creates no new local state. } - fn read_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream, buf: &mut [u8]) - -> Poll> - { + fn read_substream( + &self, + cx: &mut Context<'_>, + substream: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> { loop { // Try to read from the current (i.e. last received) frame. if !substream.current_data.is_empty() { @@ -126,27 +132,36 @@ where // Read the next data frame from the multiplexed stream. match ready!(self.io.lock().poll_read_stream(cx, substream.id))? { - Some(data) => { substream.current_data = data; } - None => { return Poll::Ready(Ok(0)) } + Some(data) => { + substream.current_data = data; + } + None => return Poll::Ready(Ok(0)), } } } - fn write_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream, buf: &[u8]) - -> Poll> - { + fn write_substream( + &self, + cx: &mut Context<'_>, + substream: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { self.io.lock().poll_write_stream(cx, substream.id, buf) } - fn flush_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream) - -> Poll> - { + fn flush_substream( + &self, + cx: &mut Context<'_>, + substream: &mut Self::Substream, + ) -> Poll> { self.io.lock().poll_flush_stream(cx, substream.id) } - fn shutdown_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream) - -> Poll> - { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + substream: &mut Self::Substream, + ) -> Poll> { self.io.lock().poll_close_stream(cx, substream.id) } @@ -176,6 +191,9 @@ pub struct Substream { impl Substream { fn new(id: LocalStreamId) -> Self { - Self { id, current_data: Bytes::new() } + Self { + id, + current_data: Bytes::new(), + } } } diff --git a/muxers/mplex/tests/async_write.rs b/muxers/mplex/tests/async_write.rs index 1414db14847..d4a1df7c4c5 100644 --- a/muxers/mplex/tests/async_write.rs +++ b/muxers/mplex/tests/async_write.rs @@ -18,9 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use futures::{channel::oneshot, prelude::*}; use libp2p_core::{muxing, upgrade, Transport}; use libp2p_tcp::TcpConfig; -use futures::{prelude::*, channel::oneshot}; use std::sync::Arc; #[test] @@ -32,14 +32,16 @@ fn async_write() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let mut listener = transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener.next().await + let addr = listener + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -48,12 +50,19 @@ fn async_write() { tx.send(addr).unwrap(); let client = listener - .next().await + .next() + .await .unwrap() .unwrap() - .into_upgrade().unwrap().0.await.unwrap(); + .into_upgrade() + .unwrap() + .0 + .await + .unwrap(); - let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)).await.unwrap(); + let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)) + .await + .unwrap(); let mut buf = Vec::new(); outbound.read_to_end(&mut buf).await.unwrap(); @@ -62,13 +71,16 @@ fn async_write() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let client = Arc::new(transport.dial(rx.await.unwrap()).unwrap().await.unwrap()); let mut inbound = loop { - if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()).await.unwrap() - .into_inbound_substream() { + if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()) + .await + .unwrap() + .into_inbound_substream() + { break s; } }; diff --git a/muxers/mplex/tests/two_peers.rs b/muxers/mplex/tests/two_peers.rs index 54b939a548a..eb0526f4044 100644 --- a/muxers/mplex/tests/two_peers.rs +++ b/muxers/mplex/tests/two_peers.rs @@ -18,9 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use futures::{channel::oneshot, prelude::*}; use libp2p_core::{muxing, upgrade, Transport}; use libp2p_tcp::TcpConfig; -use futures::{channel::oneshot, prelude::*}; use std::sync::Arc; #[test] @@ -32,14 +32,16 @@ fn client_to_server_outbound() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let mut listener = transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener.next().await + let addr = listener + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -48,12 +50,19 @@ fn client_to_server_outbound() { tx.send(addr).unwrap(); let client = listener - .next().await + .next() + .await .unwrap() .unwrap() - .into_upgrade().unwrap().0.await.unwrap(); + .into_upgrade() + .unwrap() + .0 + .await + .unwrap(); - let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)).await.unwrap(); + let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)) + .await + .unwrap(); let mut buf = Vec::new(); outbound.read_to_end(&mut buf).await.unwrap(); @@ -62,13 +71,16 @@ fn client_to_server_outbound() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let client = Arc::new(transport.dial(rx.await.unwrap()).unwrap().await.unwrap()); let mut inbound = loop { - if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()).await.unwrap() - .into_inbound_substream() { + if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()) + .await + .unwrap() + .into_inbound_substream() + { break s; } }; @@ -88,14 +100,16 @@ fn client_to_server_inbound() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let mut listener = transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener.next().await + let addr = listener + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -103,15 +117,25 @@ fn client_to_server_inbound() { tx.send(addr).unwrap(); - let client = Arc::new(listener - .next().await - .unwrap() - .unwrap() - .into_upgrade().unwrap().0.await.unwrap()); + let client = Arc::new( + listener + .next() + .await + .unwrap() + .unwrap() + .into_upgrade() + .unwrap() + .0 + .await + .unwrap(), + ); let mut inbound = loop { - if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()).await.unwrap() - .into_inbound_substream() { + if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()) + .await + .unwrap() + .into_inbound_substream() + { break s; } }; @@ -123,11 +147,13 @@ fn client_to_server_inbound() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); - let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)).await.unwrap(); + let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)) + .await + .unwrap(); outbound.write_all(b"hello world").await.unwrap(); outbound.close().await.unwrap(); diff --git a/muxers/yamux/src/lib.rs b/muxers/yamux/src/lib.rs index eb47ad8d0ce..941e0fefd8e 100644 --- a/muxers/yamux/src/lib.rs +++ b/muxers/yamux/src/lib.rs @@ -21,11 +21,20 @@ //! Implements the Yamux multiplexing protocol for libp2p, see also the //! [specification](https://github.com/hashicorp/yamux/blob/master/spec.md). -use futures::{future, prelude::*, ready, stream::{BoxStream, LocalBoxStream}}; +use futures::{ + future, + prelude::*, + ready, + stream::{BoxStream, LocalBoxStream}, +}; use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent}; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use parking_lot::Mutex; -use std::{fmt, io, iter, pin::Pin, task::{Context, Poll}}; +use std::{ + fmt, io, iter, + pin::Pin, + task::{Context, Poll}, +}; use thiserror::Error; /// A Yamux connection. @@ -50,7 +59,7 @@ pub struct OpenSubstreamToken(()); impl Yamux> where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { /// Create a new Yamux connection. fn new(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self { @@ -59,7 +68,7 @@ where let inner = Inner { incoming: Incoming { stream: yamux::into_stream(conn).err_into().boxed(), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, }, control: ctrl, }; @@ -69,7 +78,7 @@ where impl Yamux> where - C: AsyncRead + AsyncWrite + Unpin + 'static + C: AsyncRead + AsyncWrite + Unpin + 'static, { /// Create a new Yamux connection (which is ![`Send`]). fn local(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self { @@ -78,7 +87,7 @@ where let inner = Inner { incoming: LocalIncoming { stream: yamux::into_stream(conn).err_into().boxed_local(), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, }, control: ctrl, }; @@ -91,20 +100,21 @@ pub type YamuxResult = Result; /// > **Note**: This implementation never emits [`StreamMuxerEvent::AddressChange`] events. impl StreamMuxer for Yamux where - S: Stream> + Unpin + S: Stream> + Unpin, { type Substream = yamux::Stream; type OutboundSubstream = OpenSubstreamToken; type Error = YamuxError; - fn poll_event(&self, c: &mut Context<'_>) - -> Poll>> - { + fn poll_event( + &self, + c: &mut Context<'_>, + ) -> Poll>> { let mut inner = self.0.lock(); match ready!(inner.incoming.poll_next_unpin(c)) { Some(Ok(s)) => Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(s))), Some(Err(e)) => Poll::Ready(Err(e)), - None => Poll::Ready(Err(yamux::ConnectionError::Closed.into())) + None => Poll::Ready(Err(yamux::ConnectionError::Closed.into())), } } @@ -112,53 +122,71 @@ where OpenSubstreamToken(()) } - fn poll_outbound(&self, c: &mut Context<'_>, _: &mut OpenSubstreamToken) - -> Poll> - { + fn poll_outbound( + &self, + c: &mut Context<'_>, + _: &mut OpenSubstreamToken, + ) -> Poll> { let mut inner = self.0.lock(); - Pin::new(&mut inner.control).poll_open_stream(c).map_err(YamuxError) + Pin::new(&mut inner.control) + .poll_open_stream(c) + .map_err(YamuxError) } fn destroy_outbound(&self, _: Self::OutboundSubstream) { self.0.lock().control.abort_open_stream() } - fn read_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream, b: &mut [u8]) - -> Poll> - { - Pin::new(s).poll_read(c, b).map_err(|e| YamuxError(e.into())) - } - - fn write_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream, b: &[u8]) - -> Poll> - { - Pin::new(s).poll_write(c, b).map_err(|e| YamuxError(e.into())) - } - - fn flush_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream) - -> Poll> - { + fn read_substream( + &self, + c: &mut Context<'_>, + s: &mut Self::Substream, + b: &mut [u8], + ) -> Poll> { + Pin::new(s) + .poll_read(c, b) + .map_err(|e| YamuxError(e.into())) + } + + fn write_substream( + &self, + c: &mut Context<'_>, + s: &mut Self::Substream, + b: &[u8], + ) -> Poll> { + Pin::new(s) + .poll_write(c, b) + .map_err(|e| YamuxError(e.into())) + } + + fn flush_substream( + &self, + c: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { Pin::new(s).poll_flush(c).map_err(|e| YamuxError(e.into())) } - fn shutdown_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream) - -> Poll> - { + fn shutdown_substream( + &self, + c: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { Pin::new(s).poll_close(c).map_err(|e| YamuxError(e.into())) } - fn destroy_substream(&self, _: Self::Substream) { } + fn destroy_substream(&self, _: Self::Substream) {} fn close(&self, c: &mut Context<'_>) -> Poll> { let mut inner = self.0.lock(); if let std::task::Poll::Ready(x) = Pin::new(&mut inner.control).poll_close(c) { - return Poll::Ready(x.map_err(YamuxError)) + return Poll::Ready(x.map_err(YamuxError)); } while let std::task::Poll::Ready(x) = inner.incoming.poll_next_unpin(c) { match x { - Some(Ok(_)) => {} // drop inbound stream + Some(Ok(_)) => {} // drop inbound stream Some(Err(e)) => return Poll::Ready(Err(e)), - None => return Poll::Ready(Ok(())) + None => return Poll::Ready(Ok(())), } } Poll::Pending @@ -173,7 +201,7 @@ where #[derive(Clone)] pub struct YamuxConfig { inner: yamux::Config, - mode: Option + mode: Option, } /// The window update mode determines when window updates are @@ -299,7 +327,7 @@ impl UpgradeInfo for YamuxLocalConfig { impl InboundUpgrade for YamuxConfig where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = Yamux>; type Error = io::Error; @@ -313,7 +341,7 @@ where impl InboundUpgrade for YamuxLocalConfig where - C: AsyncRead + AsyncWrite + Unpin + 'static + C: AsyncRead + AsyncWrite + Unpin + 'static, { type Output = Yamux>; type Error = io::Error; @@ -328,7 +356,7 @@ where impl OutboundUpgrade for YamuxConfig where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = Yamux>; type Error = io::Error; @@ -342,7 +370,7 @@ where impl OutboundUpgrade for YamuxLocalConfig where - C: AsyncRead + AsyncWrite + Unpin + 'static + C: AsyncRead + AsyncWrite + Unpin + 'static, { type Output = Yamux>; type Error = io::Error; @@ -364,7 +392,7 @@ impl From for io::Error { fn from(err: YamuxError) -> Self { match err.0 { yamux::ConnectionError::Io(e) => e, - e => io::Error::new(io::ErrorKind::Other, e) + e => io::Error::new(io::ErrorKind::Other, e), } } } @@ -372,7 +400,7 @@ impl From for io::Error { /// The [`futures::stream::Stream`] of incoming substreams. pub struct Incoming { stream: BoxStream<'static, Result>, - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } impl fmt::Debug for Incoming { @@ -384,7 +412,7 @@ impl fmt::Debug for Incoming { /// The [`futures::stream::Stream`] of incoming substreams (`!Send`). pub struct LocalIncoming { stream: LocalBoxStream<'static, Result>, - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } impl fmt::Debug for LocalIncoming { @@ -396,7 +424,10 @@ impl fmt::Debug for LocalIncoming { impl Stream for Incoming { type Item = Result; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> std::task::Poll> { self.stream.as_mut().poll_next_unpin(cx) } @@ -405,13 +436,15 @@ impl Stream for Incoming { } } -impl Unpin for Incoming { -} +impl Unpin for Incoming {} impl Stream for LocalIncoming { type Item = Result; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> std::task::Poll> { self.stream.as_mut().poll_next_unpin(cx) } @@ -420,5 +453,4 @@ impl Stream for LocalIncoming { } } -impl Unpin for LocalIncoming { -} +impl Unpin for LocalIncoming {} diff --git a/protocols/floodsub/build.rs b/protocols/floodsub/build.rs index 3de5b750ca2..a3de99880dc 100644 --- a/protocols/floodsub/build.rs +++ b/protocols/floodsub/build.rs @@ -19,6 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/rpc.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/rpc.proto"], &["src"]).unwrap(); } - diff --git a/protocols/floodsub/src/layer.rs b/protocols/floodsub/src/layer.rs index b2916fa0605..eb5a7cb30b2 100644 --- a/protocols/floodsub/src/layer.rs +++ b/protocols/floodsub/src/layer.rs @@ -18,26 +18,24 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::protocol::{FloodsubProtocol, FloodsubMessage, FloodsubRpc, FloodsubSubscription, FloodsubSubscriptionAction}; +use crate::protocol::{ + FloodsubMessage, FloodsubProtocol, FloodsubRpc, FloodsubSubscription, + FloodsubSubscriptionAction, +}; use crate::topic::Topic; use crate::FloodsubConfig; use cuckoofilter::{CuckooError, CuckooFilter}; use fnv::FnvHashSet; -use libp2p_core::{PeerId, connection::ConnectionId}; +use libp2p_core::{connection::ConnectionId, PeerId}; use libp2p_swarm::{ - NetworkBehaviour, - NetworkBehaviourAction, - PollParameters, - ProtocolsHandler, - OneShotHandler, - NotifyHandler, - DialPeerCondition, + DialPeerCondition, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, OneShotHandler, + PollParameters, ProtocolsHandler, }; use log::warn; use smallvec::SmallVec; -use std::{collections::VecDeque, iter}; use std::collections::hash_map::{DefaultHasher, HashMap}; use std::task::{Context, Poll}; +use std::{collections::VecDeque, iter}; /// Network behaviour that handles the floodsub protocol. pub struct Floodsub { @@ -87,23 +85,25 @@ impl Floodsub { // Send our topics to this node if we're already connected to it. if self.connected_peers.contains_key(&peer_id) { for topic in self.subscribed_topics.iter().cloned() { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::Any, - event: FloodsubRpc { - messages: Vec::new(), - subscriptions: vec![FloodsubSubscription { - topic, - action: FloodsubSubscriptionAction::Subscribe, - }], - }, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id, + handler: NotifyHandler::Any, + event: FloodsubRpc { + messages: Vec::new(), + subscriptions: vec![FloodsubSubscription { + topic, + action: FloodsubSubscriptionAction::Subscribe, + }], + }, + }); } } if self.target_peers.insert(peer_id) { self.events.push_back(NetworkBehaviourAction::DialPeer { - peer_id, condition: DialPeerCondition::Disconnected + peer_id, + condition: DialPeerCondition::Disconnected, }); } } @@ -123,17 +123,18 @@ impl Floodsub { } for peer in self.connected_peers.keys() { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *peer, - handler: NotifyHandler::Any, - event: FloodsubRpc { - messages: Vec::new(), - subscriptions: vec![FloodsubSubscription { - topic: topic.clone(), - action: FloodsubSubscriptionAction::Subscribe, - }], - }, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *peer, + handler: NotifyHandler::Any, + event: FloodsubRpc { + messages: Vec::new(), + subscriptions: vec![FloodsubSubscription { + topic: topic.clone(), + action: FloodsubSubscriptionAction::Subscribe, + }], + }, + }); } self.subscribed_topics.push(topic); @@ -148,23 +149,24 @@ impl Floodsub { pub fn unsubscribe(&mut self, topic: Topic) -> bool { let pos = match self.subscribed_topics.iter().position(|t| *t == topic) { Some(pos) => pos, - None => return false + None => return false, }; self.subscribed_topics.remove(pos); for peer in self.connected_peers.keys() { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *peer, - handler: NotifyHandler::Any, - event: FloodsubRpc { - messages: Vec::new(), - subscriptions: vec![FloodsubSubscription { - topic: topic.clone(), - action: FloodsubSubscriptionAction::Unsubscribe, - }], - }, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *peer, + handler: NotifyHandler::Any, + event: FloodsubRpc { + messages: Vec::new(), + subscriptions: vec![FloodsubSubscription { + topic: topic.clone(), + action: FloodsubSubscriptionAction::Unsubscribe, + }], + }, + }); } true @@ -184,16 +186,29 @@ impl Floodsub { /// /// /// > **Note**: Doesn't do anything if we're not subscribed to any of the topics. - pub fn publish_many(&mut self, topic: impl IntoIterator>, data: impl Into>) { + pub fn publish_many( + &mut self, + topic: impl IntoIterator>, + data: impl Into>, + ) { self.publish_many_inner(topic, data, true) } /// Publishes a message with multiple topics to the network, even if we're not subscribed to any of the topics. - pub fn publish_many_any(&mut self, topic: impl IntoIterator>, data: impl Into>) { + pub fn publish_many_any( + &mut self, + topic: impl IntoIterator>, + data: impl Into>, + ) { self.publish_many_inner(topic, data, false) } - fn publish_many_inner(&mut self, topic: impl IntoIterator>, data: impl Into>, check_self_subscriptions: bool) { + fn publish_many_inner( + &mut self, + topic: impl IntoIterator>, + data: impl Into>, + check_self_subscriptions: bool, + ) { let message = FloodsubMessage { source: self.config.local_peer_id, data: data.into(), @@ -204,39 +219,48 @@ impl Floodsub { topics: topic.into_iter().map(Into::into).collect(), }; - let self_subscribed = self.subscribed_topics.iter().any(|t| message.topics.iter().any(|u| t == u)); + let self_subscribed = self + .subscribed_topics + .iter() + .any(|t| message.topics.iter().any(|u| t == u)); if self_subscribed { if let Err(e @ CuckooError::NotEnoughSpace) = self.received.add(&message) { warn!( "Message was added to 'received' Cuckoofilter but some \ - other message was removed as a consequence: {}", e, + other message was removed as a consequence: {}", + e, ); } if self.config.subscribe_local_messages { - self.events.push_back( - NetworkBehaviourAction::GenerateEvent(FloodsubEvent::Message(message.clone()))); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + FloodsubEvent::Message(message.clone()), + )); } } // Don't publish the message if we have to check subscriptions // and we're not subscribed ourselves to any of the topics. if check_self_subscriptions && !self_subscribed { - return + return; } // Send to peers we know are subscribed to the topic. for (peer_id, sub_topic) in self.connected_peers.iter() { - if !sub_topic.iter().any(|t| message.topics.iter().any(|u| t == u)) { + if !sub_topic + .iter() + .any(|t| message.topics.iter().any(|u| t == u)) + { continue; } - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *peer_id, - handler: NotifyHandler::Any, - event: FloodsubRpc { - subscriptions: Vec::new(), - messages: vec![message.clone()], - } - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *peer_id, + handler: NotifyHandler::Any, + event: FloodsubRpc { + subscriptions: Vec::new(), + messages: vec![message.clone()], + }, + }); } } } @@ -253,17 +277,18 @@ impl NetworkBehaviour for Floodsub { // We need to send our subscriptions to the newly-connected node. if self.target_peers.contains(id) { for topic in self.subscribed_topics.iter().cloned() { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *id, - handler: NotifyHandler::Any, - event: FloodsubRpc { - messages: Vec::new(), - subscriptions: vec![FloodsubSubscription { - topic, - action: FloodsubSubscriptionAction::Subscribe, - }], - }, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *id, + handler: NotifyHandler::Any, + event: FloodsubRpc { + messages: Vec::new(), + subscriptions: vec![FloodsubSubscription { + topic, + action: FloodsubSubscriptionAction::Subscribe, + }], + }, + }); } } @@ -279,7 +304,7 @@ impl NetworkBehaviour for Floodsub { if self.target_peers.contains(id) { self.events.push_back(NetworkBehaviourAction::DialPeer { peer_id: *id, - condition: DialPeerCondition::Disconnected + condition: DialPeerCondition::Disconnected, }); } } @@ -306,19 +331,26 @@ impl NetworkBehaviour for Floodsub { if !remote_peer_topics.contains(&subscription.topic) { remote_peer_topics.push(subscription.topic.clone()); } - self.events.push_back(NetworkBehaviourAction::GenerateEvent(FloodsubEvent::Subscribed { - peer_id: propagation_source, - topic: subscription.topic, - })); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + FloodsubEvent::Subscribed { + peer_id: propagation_source, + topic: subscription.topic, + }, + )); } FloodsubSubscriptionAction::Unsubscribe => { - if let Some(pos) = remote_peer_topics.iter().position(|t| t == &subscription.topic ) { + if let Some(pos) = remote_peer_topics + .iter() + .position(|t| t == &subscription.topic) + { remote_peer_topics.remove(pos); } - self.events.push_back(NetworkBehaviourAction::GenerateEvent(FloodsubEvent::Unsubscribed { - peer_id: propagation_source, - topic: subscription.topic, - })); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + FloodsubEvent::Unsubscribed { + peer_id: propagation_source, + topic: subscription.topic, + }, + )); } } } @@ -330,20 +362,27 @@ impl NetworkBehaviour for Floodsub { // Use `self.received` to skip the messages that we have already received in the past. // Note that this can result in false positives. match self.received.test_and_add(&message) { - Ok(true) => {}, // Message was added. + Ok(true) => {} // Message was added. Ok(false) => continue, // Message already existed. - Err(e @ CuckooError::NotEnoughSpace) => { // Message added, but some other removed. + Err(e @ CuckooError::NotEnoughSpace) => { + // Message added, but some other removed. warn!( "Message was added to 'received' Cuckoofilter but some \ - other message was removed as a consequence: {}", e, + other message was removed as a consequence: {}", + e, ); } } // Add the message to be dispatched to the user. - if self.subscribed_topics.iter().any(|t| message.topics.iter().any(|u| t == u)) { + if self + .subscribed_topics + .iter() + .any(|t| message.topics.iter().any(|u| t == u)) + { let event = FloodsubEvent::Message(message.clone()); - self.events.push_back(NetworkBehaviourAction::GenerateEvent(event)); + self.events + .push_back(NetworkBehaviourAction::GenerateEvent(event)); } // Propagate the message to everyone else who is subscribed to any of the topics. @@ -352,27 +391,34 @@ impl NetworkBehaviour for Floodsub { continue; } - if !subscr_topics.iter().any(|t| message.topics.iter().any(|u| t == u)) { + if !subscr_topics + .iter() + .any(|t| message.topics.iter().any(|u| t == u)) + { continue; } if let Some(pos) = rpcs_to_dispatch.iter().position(|(p, _)| p == peer_id) { rpcs_to_dispatch[pos].1.messages.push(message.clone()); } else { - rpcs_to_dispatch.push((*peer_id, FloodsubRpc { - subscriptions: Vec::new(), - messages: vec![message.clone()], - })); + rpcs_to_dispatch.push(( + *peer_id, + FloodsubRpc { + subscriptions: Vec::new(), + messages: vec![message.clone()], + }, + )); } } } for (peer_id, rpc) in rpcs_to_dispatch { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::Any, - event: rpc, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id, + handler: NotifyHandler::Any, + event: rpc, + }); } } diff --git a/protocols/floodsub/src/lib.rs b/protocols/floodsub/src/lib.rs index 8e7014bedaa..5481ddbf43d 100644 --- a/protocols/floodsub/src/lib.rs +++ b/protocols/floodsub/src/lib.rs @@ -50,7 +50,7 @@ impl FloodsubConfig { pub fn new(local_peer_id: PeerId) -> Self { Self { local_peer_id, - subscribe_local_messages: false + subscribe_local_messages: false, } } } diff --git a/protocols/floodsub/src/protocol.rs b/protocols/floodsub/src/protocol.rs index 1b942549b22..df694b2e06d 100644 --- a/protocols/floodsub/src/protocol.rs +++ b/protocols/floodsub/src/protocol.rs @@ -20,10 +20,13 @@ use crate::rpc_proto; use crate::topic::Topic; -use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, PeerId, upgrade}; +use futures::{ + io::{AsyncRead, AsyncWrite}, + AsyncWriteExt, Future, +}; +use libp2p_core::{upgrade, InboundUpgrade, OutboundUpgrade, PeerId, UpgradeInfo}; use prost::Message; use std::{error, fmt, io, iter, pin::Pin}; -use futures::{Future, io::{AsyncRead, AsyncWrite}, AsyncWriteExt}; /// Implementation of `ConnectionUpgrade` for the floodsub protocol. #[derive(Debug, Clone, Default)] @@ -61,21 +64,18 @@ where let mut messages = Vec::with_capacity(rpc.publish.len()); for publish in rpc.publish.into_iter() { messages.push(FloodsubMessage { - source: PeerId::from_bytes(&publish.from.unwrap_or_default()).map_err(|_| { - FloodsubDecodeError::InvalidPeerId - })?, + source: PeerId::from_bytes(&publish.from.unwrap_or_default()) + .map_err(|_| FloodsubDecodeError::InvalidPeerId)?, data: publish.data.unwrap_or_default(), sequence_number: publish.seqno.unwrap_or_default(), - topics: publish.topic_ids - .into_iter() - .map(Topic::new) - .collect(), + topics: publish.topic_ids.into_iter().map(Topic::new).collect(), }); } Ok(FloodsubRpc { messages, - subscriptions: rpc.subscriptions + subscriptions: rpc + .subscriptions .into_iter() .map(|sub| FloodsubSubscription { action: if Some(true) == sub.subscribe { @@ -117,12 +117,15 @@ impl From for FloodsubDecodeError { impl fmt::Display for FloodsubDecodeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - FloodsubDecodeError::ReadError(ref err) => - write!(f, "Error while reading from socket: {}", err), - FloodsubDecodeError::ProtobufError(ref err) => - write!(f, "Error while decoding protobuf: {}", err), - FloodsubDecodeError::InvalidPeerId => - write!(f, "Error while decoding PeerId from message"), + FloodsubDecodeError::ReadError(ref err) => { + write!(f, "Error while reading from socket: {}", err) + } + FloodsubDecodeError::ProtobufError(ref err) => { + write!(f, "Error while decoding protobuf: {}", err) + } + FloodsubDecodeError::InvalidPeerId => { + write!(f, "Error while decoding PeerId from message") + } } } } @@ -179,32 +182,30 @@ impl FloodsubRpc { /// Turns this `FloodsubRpc` into a message that can be sent to a substream. fn into_bytes(self) -> Vec { let rpc = rpc_proto::Rpc { - publish: self.messages.into_iter() - .map(|msg| { - rpc_proto::Message { - from: Some(msg.source.to_bytes()), - data: Some(msg.data), - seqno: Some(msg.sequence_number), - topic_ids: msg.topics - .into_iter() - .map(|topic| topic.into()) - .collect() - } + publish: self + .messages + .into_iter() + .map(|msg| rpc_proto::Message { + from: Some(msg.source.to_bytes()), + data: Some(msg.data), + seqno: Some(msg.sequence_number), + topic_ids: msg.topics.into_iter().map(|topic| topic.into()).collect(), }) .collect(), - subscriptions: self.subscriptions.into_iter() - .map(|topic| { - rpc_proto::rpc::SubOpts { - subscribe: Some(topic.action == FloodsubSubscriptionAction::Subscribe), - topic_id: Some(topic.topic.into()) - } + subscriptions: self + .subscriptions + .into_iter() + .map(|topic| rpc_proto::rpc::SubOpts { + subscribe: Some(topic.action == FloodsubSubscriptionAction::Subscribe), + topic_id: Some(topic.topic.into()), }) - .collect() + .collect(), }; let mut buf = Vec::with_capacity(rpc.encoded_len()); - rpc.encode(&mut buf).expect("Vec provides capacity as needed"); + rpc.encode(&mut buf) + .expect("Vec provides capacity as needed"); buf } } diff --git a/protocols/gossipsub/src/behaviour.rs b/protocols/gossipsub/src/behaviour.rs index 8a9c1b9efe0..a4f4ec24bfc 100644 --- a/protocols/gossipsub/src/behaviour.rs +++ b/protocols/gossipsub/src/behaviour.rs @@ -3198,9 +3198,13 @@ where NetworkBehaviourAction::ReportObservedAddr { address, score } => { NetworkBehaviourAction::ReportObservedAddr { address, score } } - NetworkBehaviourAction::CloseConnection { peer_id, connection } => { - NetworkBehaviourAction::CloseConnection { peer_id, connection } - } + NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + } => NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }, }); } diff --git a/protocols/gossipsub/src/protocol.rs b/protocols/gossipsub/src/protocol.rs index 19293f58d7e..199d210452a 100644 --- a/protocols/gossipsub/src/protocol.rs +++ b/protocols/gossipsub/src/protocol.rs @@ -27,12 +27,12 @@ use crate::types::{ GossipsubControlAction, GossipsubRpc, GossipsubSubscription, GossipsubSubscriptionAction, MessageId, PeerInfo, PeerKind, RawGossipsubMessage, }; +use asynchronous_codec::{Decoder, Encoder, Framed}; use byteorder::{BigEndian, ByteOrder}; use bytes::Bytes; use bytes::BytesMut; use futures::future; use futures::prelude::*; -use asynchronous_codec::{Decoder, Encoder, Framed}; use libp2p_core::{ identity::PublicKey, InboundUpgrade, OutboundUpgrade, PeerId, ProtocolName, UpgradeInfo, }; diff --git a/protocols/gossipsub/tests/smoke.rs b/protocols/gossipsub/tests/smoke.rs index 841929a7f0c..3cf3f882427 100644 --- a/protocols/gossipsub/tests/smoke.rs +++ b/protocols/gossipsub/tests/smoke.rs @@ -51,11 +51,13 @@ impl Future for Graph { for (addr, node) in &mut self.nodes { loop { match node.poll_next_unpin(cx) { - Poll::Ready(Some(SwarmEvent::Behaviour(event))) => return Poll::Ready((addr.clone(), event)), + Poll::Ready(Some(SwarmEvent::Behaviour(event))) => { + return Poll::Ready((addr.clone(), event)) + } Poll::Ready(Some(_)) => {} Poll::Ready(None) => panic!("unexpected None when polling nodes"), Poll::Pending => break, - } + } } } @@ -226,7 +228,11 @@ fn multi_hop_propagation() { graph = graph.drain_poll(); // Publish a single message. - graph.nodes[0].1.behaviour_mut().publish(topic, vec![1, 2, 3]).unwrap(); + graph.nodes[0] + .1 + .behaviour_mut() + .publish(topic, vec![1, 2, 3]) + .unwrap(); // Wait for all nodes to receive the published message. let mut received_msgs = 0; diff --git a/protocols/identify/build.rs b/protocols/identify/build.rs index 1b0feff6a40..56c7b20121a 100644 --- a/protocols/identify/build.rs +++ b/protocols/identify/build.rs @@ -19,6 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/structs.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/structs.proto"], &["src"]).unwrap(); } - diff --git a/protocols/identify/src/handler.rs b/protocols/identify/src/handler.rs index 11c239cdfab..f0d05f79dc9 100644 --- a/protocols/identify/src/handler.rs +++ b/protocols/identify/src/handler.rs @@ -19,32 +19,16 @@ // DEALINGS IN THE SOFTWARE. use crate::protocol::{ - IdentifyProtocol, - IdentifyPushProtocol, - IdentifyInfo, - InboundPush, - OutboundPush, - ReplySubstream + IdentifyInfo, IdentifyProtocol, IdentifyPushProtocol, InboundPush, OutboundPush, ReplySubstream, }; use futures::prelude::*; -use libp2p_core::either::{ - EitherError, - EitherOutput, -}; +use libp2p_core::either::{EitherError, EitherOutput}; use libp2p_core::upgrade::{ - EitherUpgrade, - InboundUpgrade, - OutboundUpgrade, - SelectUpgrade, - UpgradeError, + EitherUpgrade, InboundUpgrade, OutboundUpgrade, SelectUpgrade, UpgradeError, }; use libp2p_swarm::{ - NegotiatedSubstream, - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + KeepAlive, NegotiatedSubstream, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; use smallvec::SmallVec; use std::{io, pin::Pin, task::Context, task::Poll, time::Duration}; @@ -57,12 +41,14 @@ use wasm_timer::Delay; /// permitting the underlying connection to be closed. pub struct IdentifyHandler { /// Pending events to yield. - events: SmallVec<[ProtocolsHandlerEvent< + events: SmallVec< + [ProtocolsHandlerEvent< EitherUpgrade>, (), IdentifyHandlerEvent, io::Error, - >; 4]>, + >; 4], + >, /// Future that fires when we need to identify the node again. next_id: Delay, @@ -114,28 +100,23 @@ impl ProtocolsHandler for IdentifyHandler { fn listen_protocol(&self) -> SubstreamProtocol { SubstreamProtocol::new( - SelectUpgrade::new( - IdentifyProtocol, - IdentifyPushProtocol::inbound(), - ), ()) + SelectUpgrade::new(IdentifyProtocol, IdentifyPushProtocol::inbound()), + (), + ) } fn inject_fully_negotiated_inbound( &mut self, output: >::Output, - _: Self::InboundOpenInfo + _: Self::InboundOpenInfo, ) { match output { - EitherOutput::First(substream) => { - self.events.push( - ProtocolsHandlerEvent::Custom( - IdentifyHandlerEvent::Identify(substream))) - } - EitherOutput::Second(info) => { - self.events.push( - ProtocolsHandlerEvent::Custom( - IdentifyHandlerEvent::Identified(info))) - } + EitherOutput::First(substream) => self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::Identify(substream), + )), + EitherOutput::Second(info) => self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::Identified(info), + )), } } @@ -146,39 +127,42 @@ impl ProtocolsHandler for IdentifyHandler { ) { match output { EitherOutput::First(remote_info) => { - self.events.push( - ProtocolsHandlerEvent::Custom( - IdentifyHandlerEvent::Identified(remote_info))); + self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::Identified(remote_info), + )); self.keep_alive = KeepAlive::No; } - EitherOutput::Second(()) => self.events.push( - ProtocolsHandlerEvent::Custom(IdentifyHandlerEvent::IdentificationPushed)) + EitherOutput::Second(()) => self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::IdentificationPushed, + )), } } fn inject_event(&mut self, IdentifyPush(push): Self::InEvent) { - self.events.push(ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new( - EitherUpgrade::B( - IdentifyPushProtocol::outbound(push)), ()) - }); + self.events + .push(ProtocolsHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new( + EitherUpgrade::B(IdentifyPushProtocol::outbound(push)), + (), + ), + }); } fn inject_dial_upgrade_error( &mut self, _info: Self::OutboundOpenInfo, err: ProtocolsHandlerUpgrErr< - >::Error - > + >::Error, + >, ) { let err = err.map_upgrade_err(|e| match e { UpgradeError::Select(e) => UpgradeError::Select(e), UpgradeError::Apply(EitherError::A(ioe)) => UpgradeError::Apply(ioe), UpgradeError::Apply(EitherError::B(ioe)) => UpgradeError::Apply(ioe), }); - self.events.push( - ProtocolsHandlerEvent::Custom( - IdentifyHandlerEvent::IdentificationError(err))); + self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::IdentificationError(err), + )); self.keep_alive = KeepAlive::No; self.next_id.reset(self.interval); } @@ -187,7 +171,10 @@ impl ProtocolsHandler for IdentifyHandler { self.keep_alive } - fn poll(&mut self, cx: &mut Context<'_>) -> Poll< + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< ProtocolsHandlerEvent< Self::OutboundProtocol, Self::OutboundOpenInfo, @@ -196,9 +183,7 @@ impl ProtocolsHandler for IdentifyHandler { >, > { if !self.events.is_empty() { - return Poll::Ready( - self.events.remove(0), - ); + return Poll::Ready(self.events.remove(0)); } // Poll the future that fires when we need to identify the node again. @@ -207,11 +192,11 @@ impl ProtocolsHandler for IdentifyHandler { Poll::Ready(Ok(())) => { self.next_id.reset(self.interval); let ev = ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(EitherUpgrade::A(IdentifyProtocol), ()) + protocol: SubstreamProtocol::new(EitherUpgrade::A(IdentifyProtocol), ()), }; Poll::Ready(ev) } - Poll::Ready(Err(err)) => Poll::Ready(ProtocolsHandlerEvent::Close(err)) + Poll::Ready(Err(err)) => Poll::Ready(ProtocolsHandlerEvent::Close(err)), } } } diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 0b60cb03fc9..caa737f669c 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -22,26 +22,16 @@ use crate::handler::{IdentifyHandler, IdentifyHandlerEvent, IdentifyPush}; use crate::protocol::{IdentifyInfo, ReplySubstream}; use futures::prelude::*; use libp2p_core::{ - ConnectedPoint, - Multiaddr, - PeerId, - PublicKey, connection::{ConnectionId, ListenerId}, - upgrade::UpgradeError + upgrade::UpgradeError, + ConnectedPoint, Multiaddr, PeerId, PublicKey, }; use libp2p_swarm::{ - AddressScore, - DialPeerCondition, - NegotiatedSubstream, - NetworkBehaviour, - NetworkBehaviourAction, - NotifyHandler, - PollParameters, - ProtocolsHandler, - ProtocolsHandlerUpgrErr + AddressScore, DialPeerCondition, NegotiatedSubstream, NetworkBehaviour, NetworkBehaviourAction, + NotifyHandler, PollParameters, ProtocolsHandler, ProtocolsHandlerUpgrErr, }; use std::{ - collections::{HashSet, HashMap, VecDeque}, + collections::{HashMap, HashSet, VecDeque}, io, pin::Pin, task::Context, @@ -74,13 +64,13 @@ enum Reply { Queued { peer: PeerId, io: ReplySubstream, - observed: Multiaddr + observed: Multiaddr, }, /// The reply is being sent. Sending { peer: PeerId, io: Pin> + Send>>, - } + }, } /// Configuration for the [`Identify`] [`NetworkBehaviour`]. @@ -178,14 +168,14 @@ impl Identify { /// Initiates an active push of the local peer information to the given peers. pub fn push(&mut self, peers: I) where - I: IntoIterator + I: IntoIterator, { for p in peers { if self.pending_push.insert(p) { if !self.connected.contains_key(&p) { self.events.push_back(NetworkBehaviourAction::DialPeer { peer_id: p, - condition: DialPeerCondition::Disconnected + condition: DialPeerCondition::Disconnected, }); } } @@ -201,16 +191,29 @@ impl NetworkBehaviour for Identify { IdentifyHandler::new(self.config.initial_delay, self.config.interval) } - fn inject_connection_established(&mut self, peer_id: &PeerId, conn: &ConnectionId, endpoint: &ConnectedPoint) { + fn inject_connection_established( + &mut self, + peer_id: &PeerId, + conn: &ConnectionId, + endpoint: &ConnectedPoint, + ) { let addr = match endpoint { ConnectedPoint::Dialer { address } => address.clone(), ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr.clone(), }; - self.connected.entry(*peer_id).or_default().insert(*conn, addr); + self.connected + .entry(*peer_id) + .or_default() + .insert(*conn, addr); } - fn inject_connection_closed(&mut self, peer_id: &PeerId, conn: &ConnectionId, _: &ConnectedPoint) { + fn inject_connection_closed( + &mut self, + peer_id: &PeerId, + conn: &ConnectionId, + _: &ConnectedPoint, + ) { if let Some(addrs) = self.connected.get_mut(peer_id) { addrs.remove(conn); } @@ -248,41 +251,39 @@ impl NetworkBehaviour for Identify { match event { IdentifyHandlerEvent::Identified(info) => { let observed = info.observed_addr.clone(); - self.events.push_back( - NetworkBehaviourAction::GenerateEvent( - IdentifyEvent::Received { - peer_id, - info, - })); - self.events.push_back( - NetworkBehaviourAction::ReportObservedAddr { + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + IdentifyEvent::Received { peer_id, info }, + )); + self.events + .push_back(NetworkBehaviourAction::ReportObservedAddr { address: observed, score: AddressScore::Finite(1), }); } IdentifyHandlerEvent::IdentificationPushed => { - self.events.push_back( - NetworkBehaviourAction::GenerateEvent( - IdentifyEvent::Pushed { - peer_id, - })); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + IdentifyEvent::Pushed { peer_id }, + )); } IdentifyHandlerEvent::Identify(sender) => { - let observed = self.connected.get(&peer_id) + let observed = self + .connected + .get(&peer_id) .and_then(|addrs| addrs.get(&connection)) - .expect("`inject_event` is only called with an established connection \ - and `inject_connection_established` ensures there is an entry; qed"); - self.pending_replies.push_back( - Reply::Queued { - peer: peer_id, - io: sender, - observed: observed.clone() - }); + .expect( + "`inject_event` is only called with an established connection \ + and `inject_connection_established` ensures there is an entry; qed", + ); + self.pending_replies.push_back(Reply::Queued { + peer: peer_id, + io: sender, + observed: observed.clone(), + }); } IdentifyHandlerEvent::IdentificationError(error) => { - self.events.push_back( - NetworkBehaviourAction::GenerateEvent( - IdentifyEvent::Error { peer_id, error })); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + IdentifyEvent::Error { peer_id, error }, + )); } } } @@ -332,7 +333,7 @@ impl NetworkBehaviour for Identify { peer_id, event: push, handler: NotifyHandler::Any, - }) + }); } // Check for pending replies to send. @@ -360,12 +361,12 @@ impl NetworkBehaviour for Identify { Poll::Ready(Ok(())) => { let event = IdentifyEvent::Sent { peer_id: peer }; return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); - }, + } Poll::Pending => { self.pending_replies.push_back(Reply::Sending { peer, io }); if sending == to_send { // All remaining futures are NotReady - break + break; } else { reply = self.pending_replies.pop_front(); } @@ -373,13 +374,15 @@ impl NetworkBehaviour for Identify { Poll::Ready(Err(err)) => { let event = IdentifyEvent::Error { peer_id: peer, - error: ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err)) + error: ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply( + err, + )), }; return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); - }, + } } } - None => unreachable!() + None => unreachable!(), } } } @@ -438,22 +441,20 @@ fn listen_addrs(params: &impl PollParameters) -> Vec { mod tests { use super::*; use futures::pin_mut; - use libp2p_core::{ - identity, - PeerId, - muxing::StreamMuxerBox, - transport, - Transport, - upgrade - }; + use libp2p_core::{identity, muxing::StreamMuxerBox, transport, upgrade, PeerId, Transport}; + use libp2p_mplex::MplexConfig; use libp2p_noise as noise; - use libp2p_tcp::TcpConfig; use libp2p_swarm::{Swarm, SwarmEvent}; - use libp2p_mplex::MplexConfig; + use libp2p_tcp::TcpConfig; - fn transport() -> (identity::PublicKey, transport::Boxed<(PeerId, StreamMuxerBox)>) { + fn transport() -> ( + identity::PublicKey, + transport::Boxed<(PeerId, StreamMuxerBox)>, + ) { let id_keys = identity::Keypair::generate_ed25519(); - let noise_keys = noise::Keypair::::new().into_authentic(&id_keys).unwrap(); + let noise_keys = noise::Keypair::::new() + .into_authentic(&id_keys) + .unwrap(); let pubkey = id_keys.public(); let transport = TcpConfig::new() .nodelay(true) @@ -470,7 +471,8 @@ mod tests { let (pubkey, transport) = transport(); let protocol = Identify::new( IdentifyConfig::new("a".to_string(), pubkey.clone()) - .with_agent_version("b".to_string())); + .with_agent_version("b".to_string()), + ); let swarm = Swarm::new(transport, protocol, pubkey.to_peer_id()); (swarm, pubkey) }; @@ -479,12 +481,15 @@ mod tests { let (pubkey, transport) = transport(); let protocol = Identify::new( IdentifyConfig::new("c".to_string(), pubkey.clone()) - .with_agent_version("d".to_string())); + .with_agent_version("d".to_string()), + ); let swarm = Swarm::new(transport, protocol, pubkey.to_peer_id()); (swarm, pubkey) }; - swarm1.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap(); + swarm1 + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); let listen_addr = async_std::task::block_on(async { loop { @@ -509,7 +514,11 @@ mod tests { let swarm2_fut = swarm2.select_next_some(); pin_mut!(swarm2_fut); - match future::select(swarm1_fut, swarm2_fut).await.factor_second().0 { + match future::select(swarm1_fut, swarm2_fut) + .await + .factor_second() + .0 + { future::Either::Left(SwarmEvent::Behaviour(IdentifyEvent::Received { info, .. @@ -547,7 +556,8 @@ mod tests { let protocol = Identify::new( IdentifyConfig::new("a".to_string(), pubkey.clone()) // Delay identification requests so we can test the push protocol. - .with_initial_delay(Duration::from_secs(u32::MAX as u64))); + .with_initial_delay(Duration::from_secs(u32::MAX as u64)), + ); let swarm = Swarm::new(transport, protocol, pubkey.to_peer_id()); (swarm, pubkey) }; @@ -558,7 +568,8 @@ mod tests { IdentifyConfig::new("a".to_string(), pubkey.clone()) .with_agent_version("b".to_string()) // Delay identification requests so we can test the push protocol. - .with_initial_delay(Duration::from_secs(u32::MAX as u64))); + .with_initial_delay(Duration::from_secs(u32::MAX as u64)), + ); let swarm = Swarm::new(transport, protocol, pubkey.to_peer_id()); (swarm, pubkey) }; @@ -586,10 +597,15 @@ mod tests { { pin_mut!(swarm1_fut); pin_mut!(swarm2_fut); - match future::select(swarm1_fut, swarm2_fut).await.factor_second().0 { - future::Either::Left(SwarmEvent::Behaviour( - IdentifyEvent::Received { info, .. } - )) => { + match future::select(swarm1_fut, swarm2_fut) + .await + .factor_second() + .0 + { + future::Either::Left(SwarmEvent::Behaviour(IdentifyEvent::Received { + info, + .. + })) => { assert_eq!(info.public_key, pubkey2); assert_eq!(info.protocol_version, "a"); assert_eq!(info.agent_version, "b"); @@ -601,11 +617,13 @@ mod tests { // Once a connection is established, we can initiate an // active push below. } - _ => { continue } + _ => continue, } } - swarm2.behaviour_mut().push(std::iter::once(pubkey1.to_peer_id())); + swarm2 + .behaviour_mut() + .push(std::iter::once(pubkey1.to_peer_id())); } }) } diff --git a/protocols/identify/src/lib.rs b/protocols/identify/src/lib.rs index 48c0c651428..99456ed7001 100644 --- a/protocols/identify/src/lib.rs +++ b/protocols/identify/src/lib.rs @@ -47,4 +47,3 @@ mod protocol; mod structs_proto { include!(concat!(env!("OUT_DIR"), "/structs.rs")); } - diff --git a/protocols/identify/src/protocol.rs b/protocols/identify/src/protocol.rs index fafa5a37855..9604e660e9f 100644 --- a/protocols/identify/src/protocol.rs +++ b/protocols/identify/src/protocol.rs @@ -21,9 +21,8 @@ use crate::structs_proto; use futures::prelude::*; use libp2p_core::{ - Multiaddr, - PublicKey, - upgrade::{self, InboundUpgrade, OutboundUpgrade, UpgradeInfo} + upgrade::{self, InboundUpgrade, OutboundUpgrade, UpgradeInfo}, + Multiaddr, PublicKey, }; use log::{debug, trace}; use prost::Message; @@ -84,7 +83,7 @@ impl fmt::Debug for ReplySubstream { impl ReplySubstream where - T: AsyncWrite + Unpin + T: AsyncWrite + Unpin, { /// Sends back the requested information on the substream. /// @@ -158,17 +157,18 @@ where type Future = Pin> + Send>>; fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future { - send(socket, self.0.0).boxed() + send(socket, self.0 .0).boxed() } } async fn send(mut io: T, info: IdentifyInfo) -> io::Result<()> where - T: AsyncWrite + Unpin + T: AsyncWrite + Unpin, { trace!("Sending: {:?}", info); - let listen_addrs = info.listen_addrs + let listen_addrs = info + .listen_addrs .into_iter() .map(|addr| addr.to_vec()) .collect(); @@ -181,11 +181,13 @@ where public_key: Some(pubkey_bytes), listen_addrs, observed_addr: Some(info.observed_addr.to_vec()), - protocols: info.protocols + protocols: info.protocols, }; let mut bytes = Vec::with_capacity(message.encoded_len()); - message.encode(&mut bytes).expect("Vec provides capacity as needed"); + message + .encode(&mut bytes) + .expect("Vec provides capacity as needed"); upgrade::write_length_prefixed(&mut io, bytes).await?; io.close().await?; @@ -195,7 +197,7 @@ where async fn recv(mut socket: T) -> io::Result where - T: AsyncRead + AsyncWrite + Unpin + T: AsyncRead + AsyncWrite + Unpin, { socket.close().await?; @@ -207,7 +209,7 @@ where Ok(v) => v, Err(err) => { debug!("Invalid message: {:?}", err); - return Err(err) + return Err(err); } }; @@ -255,14 +257,14 @@ fn parse_proto_msg(msg: impl AsRef<[u8]>) -> Result { #[cfg(test)] mod tests { - use libp2p_tcp::TcpConfig; - use futures::{prelude::*, channel::oneshot}; + use super::*; + use futures::{channel::oneshot, prelude::*}; use libp2p_core::{ identity, + upgrade::{self, apply_inbound, apply_outbound}, Transport, - upgrade::{self, apply_outbound, apply_inbound} }; - use super::*; + use libp2p_tcp::TcpConfig; #[test] fn correct_transfer() { @@ -280,7 +282,9 @@ mod tests { .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener.next().await + let addr = listener + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -288,14 +292,20 @@ mod tests { tx.send(addr).unwrap(); let socket = listener - .next().await.unwrap().unwrap() - .into_upgrade().unwrap() - .0.await.unwrap(); + .next() + .await + .unwrap() + .unwrap() + .into_upgrade() + .unwrap() + .0 + .await + .unwrap(); let sender = apply_inbound(socket, IdentifyProtocol).await.unwrap(); - sender.send( - IdentifyInfo { + sender + .send(IdentifyInfo { public_key: send_pubkey, protocol_version: "proto_version".to_owned(), agent_version: "agent_version".to_owned(), @@ -305,27 +315,36 @@ mod tests { ], protocols: vec!["proto1".to_string(), "proto2".to_string()], observed_addr: "/ip4/100.101.102.103/tcp/5000".parse().unwrap(), - }, - ).await.unwrap(); + }) + .await + .unwrap(); }); async_std::task::block_on(async move { let transport = TcpConfig::new(); let socket = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); - let info = apply_outbound( - socket, - IdentifyProtocol, - upgrade::Version::V1 - ).await.unwrap(); - assert_eq!(info.observed_addr, "/ip4/100.101.102.103/tcp/5000".parse().unwrap()); + let info = apply_outbound(socket, IdentifyProtocol, upgrade::Version::V1) + .await + .unwrap(); + assert_eq!( + info.observed_addr, + "/ip4/100.101.102.103/tcp/5000".parse().unwrap() + ); assert_eq!(info.public_key, recv_pubkey); assert_eq!(info.protocol_version, "proto_version"); assert_eq!(info.agent_version, "agent_version"); - assert_eq!(info.listen_addrs, - &["/ip4/80.81.82.83/tcp/500".parse().unwrap(), - "/ip6/::1/udp/1000".parse().unwrap()]); - assert_eq!(info.protocols, &["proto1".to_string(), "proto2".to_string()]); + assert_eq!( + info.listen_addrs, + &[ + "/ip4/80.81.82.83/tcp/500".parse().unwrap(), + "/ip6/::1/udp/1000".parse().unwrap() + ] + ); + assert_eq!( + info.protocols, + &["proto1".to_string(), "proto2".to_string()] + ); bg_task.await; }); diff --git a/protocols/kad/build.rs b/protocols/kad/build.rs index abae8bdd169..f05e9e03190 100644 --- a/protocols/kad/build.rs +++ b/protocols/kad/build.rs @@ -19,6 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/dht.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/dht.proto"], &["src"]).unwrap(); } - diff --git a/protocols/kad/src/addresses.rs b/protocols/kad/src/addresses.rs index b0106a6f83d..f5bdd4d0fbc 100644 --- a/protocols/kad/src/addresses.rs +++ b/protocols/kad/src/addresses.rs @@ -65,9 +65,9 @@ impl Addresses { /// /// An address should only be removed if is determined to be invalid or /// otherwise unreachable. - pub fn remove(&mut self, addr: &Multiaddr) -> Result<(),()> { + pub fn remove(&mut self, addr: &Multiaddr) -> Result<(), ()> { if self.addrs.len() == 1 { - return Err(()) + return Err(()); } if let Some(pos) = self.addrs.iter().position(|a| a == addr) { @@ -100,7 +100,7 @@ impl Addresses { pub fn replace(&mut self, old: &Multiaddr, new: &Multiaddr) -> bool { if let Some(a) = self.addrs.iter_mut().find(|a| *a == old) { *a = new.clone(); - return true + return true; } false @@ -109,8 +109,6 @@ impl Addresses { impl fmt::Debug for Addresses { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list() - .entries(self.addrs.iter()) - .finish() + f.debug_list().entries(self.addrs.iter()).finish() } } diff --git a/protocols/kad/src/behaviour.rs b/protocols/kad/src/behaviour.rs index b0a2f3ee8db..4d7e3207334 100644 --- a/protocols/kad/src/behaviour.rs +++ b/protocols/kad/src/behaviour.rs @@ -22,37 +22,37 @@ mod test; -use crate::K_VALUE; use crate::addresses::Addresses; use crate::handler::{ - KademliaHandlerProto, - KademliaHandlerConfig, + KademliaHandlerConfig, KademliaHandlerEvent, KademliaHandlerIn, KademliaHandlerProto, KademliaRequestId, - KademliaHandlerEvent, - KademliaHandlerIn }; use crate::jobs::*; use crate::kbucket::{self, Distance, KBucketsTable, NodeStatus}; -use crate::protocol::{KademliaProtocolConfig, KadConnectionType, KadPeer}; -use crate::query::{Query, QueryId, QueryPool, QueryConfig, QueryPoolState}; -use crate::record::{self, store::{self, RecordStore}, Record, ProviderRecord}; +use crate::protocol::{KadConnectionType, KadPeer, KademliaProtocolConfig}; +use crate::query::{Query, QueryConfig, QueryId, QueryPool, QueryPoolState}; +use crate::record::{ + self, + store::{self, RecordStore}, + ProviderRecord, Record, +}; +use crate::K_VALUE; use fnv::{FnvHashMap, FnvHashSet}; -use libp2p_core::{ConnectedPoint, Multiaddr, PeerId, connection::{ConnectionId, ListenerId}}; +use libp2p_core::{ + connection::{ConnectionId, ListenerId}, + ConnectedPoint, Multiaddr, PeerId, +}; use libp2p_swarm::{ - DialPeerCondition, - NetworkBehaviour, - NetworkBehaviourAction, - NotifyHandler, - PollParameters, + DialPeerCondition, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters, }; -use log::{info, debug, warn}; +use log::{debug, info, warn}; use smallvec::SmallVec; -use std::{borrow::Cow, error, time::Duration}; -use std::collections::{HashSet, VecDeque, BTreeMap}; +use std::collections::{BTreeMap, HashSet, VecDeque}; use std::fmt; use std::num::NonZeroUsize; use std::task::{Context, Poll}; use std::vec; +use std::{borrow::Cow, error, time::Duration}; use wasm_timer::Instant; pub use crate::query::QueryStats; @@ -356,7 +356,7 @@ impl KademliaConfig { impl Kademlia where - for<'a> TStore: RecordStore<'a> + for<'a> TStore: RecordStore<'a>, { /// Creates a new `Kademlia` network behaviour with a default configuration. pub fn new(id: PeerId, store: TStore) -> Self { @@ -375,12 +375,14 @@ where let put_record_job = config .record_replication_interval .or(config.record_publication_interval) - .map(|interval| PutRecordJob::new( - id, - interval, - config.record_publication_interval, - config.record_ttl, - )); + .map(|interval| { + PutRecordJob::new( + id, + interval, + config.record_publication_interval, + config.record_ttl, + ) + }); let add_provider_job = config .provider_publication_interval @@ -406,42 +408,46 @@ where /// Gets an iterator over immutable references to all running queries. pub fn iter_queries(&self) -> impl Iterator> { - self.queries.iter().filter_map(|query| + self.queries.iter().filter_map(|query| { if !query.is_finished() { Some(QueryRef { query }) } else { None - }) + } + }) } /// Gets an iterator over mutable references to all running queries. pub fn iter_queries_mut(&mut self) -> impl Iterator> { - self.queries.iter_mut().filter_map(|query| + self.queries.iter_mut().filter_map(|query| { if !query.is_finished() { Some(QueryMut { query }) } else { None - }) + } + }) } /// Gets an immutable reference to a running query, if it exists. pub fn query(&self, id: &QueryId) -> Option> { - self.queries.get(id).and_then(|query| + self.queries.get(id).and_then(|query| { if !query.is_finished() { Some(QueryRef { query }) } else { None - }) + } + }) } /// Gets a mutable reference to a running query, if it exists. pub fn query_mut<'a>(&'a mut self, id: &QueryId) -> Option> { - self.queries.get_mut(id).and_then(|query| + self.queries.get_mut(id).and_then(|query| { if !query.is_finished() { Some(QueryMut { query }) } else { None - }) + } + }) } /// Adds a known listen address of a peer participating in the DHT to the @@ -466,18 +472,20 @@ where match self.kbuckets.entry(&key) { kbucket::Entry::Present(mut entry, _) => { if entry.value().insert(address) { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutingUpdated { - peer: *peer, - is_new_peer: false, - addresses: entry.value().clone(), - old_peer: None, - bucket_range: self.kbuckets - .bucket(&key) - .map(|b| b.range()) - .expect("Not kbucket::Entry::SelfEntry."), - } - )) + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutingUpdated { + peer: *peer, + is_new_peer: false, + addresses: entry.value().clone(), + old_peer: None, + bucket_range: self + .kbuckets + .bucket(&key) + .map(|b| b.range()) + .expect("Not kbucket::Entry::SelfEntry."), + }, + )) } RoutingUpdate::Success } @@ -487,41 +495,43 @@ where } kbucket::Entry::Absent(entry) => { let addresses = Addresses::new(address); - let status = - if self.connected_peers.contains(peer) { - NodeStatus::Connected - } else { - NodeStatus::Disconnected - }; + let status = if self.connected_peers.contains(peer) { + NodeStatus::Connected + } else { + NodeStatus::Disconnected + }; match entry.insert(addresses.clone(), status) { kbucket::InsertResult::Inserted => { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutingUpdated { - peer: *peer, - is_new_peer: true, - addresses, - old_peer: None, - bucket_range: self.kbuckets - .bucket(&key) - .map(|b| b.range()) - .expect("Not kbucket::Entry::SelfEntry."), - } - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutingUpdated { + peer: *peer, + is_new_peer: true, + addresses, + old_peer: None, + bucket_range: self + .kbuckets + .bucket(&key) + .map(|b| b.range()) + .expect("Not kbucket::Entry::SelfEntry."), + }, + )); RoutingUpdate::Success - }, + } kbucket::InsertResult::Full => { debug!("Bucket full. Peer not added to routing table: {}", peer); RoutingUpdate::Failed - }, + } kbucket::InsertResult::Pending { disconnected } => { - self.queued_events.push_back(NetworkBehaviourAction::DialPeer { - peer_id: disconnected.into_preimage(), - condition: DialPeerCondition::Disconnected - }); + self.queued_events + .push_back(NetworkBehaviourAction::DialPeer { + peer_id: disconnected.into_preimage(), + condition: DialPeerCondition::Disconnected, + }); RoutingUpdate::Pending - }, + } } - }, + } kbucket::Entry::SelfEntry => RoutingUpdate::Failed, } } @@ -536,9 +546,11 @@ where /// /// If the given peer or address is not in the routing table, /// this is a no-op. - pub fn remove_address(&mut self, peer: &PeerId, address: &Multiaddr) - -> Option, Addresses>> - { + pub fn remove_address( + &mut self, + peer: &PeerId, + address: &Multiaddr, + ) -> Option, Addresses>> { let key = kbucket::Key::from(*peer); match self.kbuckets.entry(&key) { kbucket::Entry::Present(mut entry, _) => { @@ -555,9 +567,7 @@ where None } } - kbucket::Entry::Absent(..) | kbucket::Entry::SelfEntry => { - None - } + kbucket::Entry::Absent(..) | kbucket::Entry::SelfEntry => None, } } @@ -565,37 +575,34 @@ where /// /// Returns `None` if the peer was not in the routing table, /// not even pending insertion. - pub fn remove_peer(&mut self, peer: &PeerId) - -> Option, Addresses>> - { + pub fn remove_peer( + &mut self, + peer: &PeerId, + ) -> Option, Addresses>> { let key = kbucket::Key::from(*peer); match self.kbuckets.entry(&key) { - kbucket::Entry::Present(entry, _) => { - Some(entry.remove()) - } - kbucket::Entry::Pending(entry, _) => { - Some(entry.remove()) - } - kbucket::Entry::Absent(..) | kbucket::Entry::SelfEntry => { - None - } + kbucket::Entry::Present(entry, _) => Some(entry.remove()), + kbucket::Entry::Pending(entry, _) => Some(entry.remove()), + kbucket::Entry::Absent(..) | kbucket::Entry::SelfEntry => None, } } /// Returns an iterator over all non-empty buckets in the routing table. - pub fn kbuckets(&mut self) - -> impl Iterator, Addresses>> - { + pub fn kbuckets( + &mut self, + ) -> impl Iterator, Addresses>> { self.kbuckets.iter().filter(|b| !b.is_empty()) } /// Returns the k-bucket for the distance to the given key. /// /// Returns `None` if the given key refers to the local key. - pub fn kbucket(&mut self, key: K) - -> Option, Addresses>> + pub fn kbucket( + &mut self, + key: K, + ) -> Option, Addresses>> where - K: Into> + Clone + K: Into> + Clone, { self.kbuckets.bucket(&key.into()) } @@ -606,9 +613,11 @@ where /// [`KademliaEvent::OutboundQueryCompleted{QueryResult::GetClosestPeers}`]. pub fn get_closest_peers(&mut self, key: K) -> QueryId where - K: Into> + Into> + Clone + K: Into> + Into> + Clone, { - let info = QueryInfo::GetClosestPeers { key: key.clone().into() }; + let info = QueryInfo::GetClosestPeers { + key: key.clone().into(), + }; let target: kbucket::Key = key.into(); let peers = self.kbuckets.closest_keys(&target); let inner = QueryInner::new(info); @@ -627,7 +636,10 @@ where if record.is_expired(Instant::now()) { self.store.remove(key) } else { - records.push(PeerRecord{ peer: None, record: record.into_owned()}); + records.push(PeerRecord { + peer: None, + record: record.into_owned(), + }); } } @@ -669,11 +681,16 @@ where /// does not update the record's expiration in local storage, thus a given record /// with an explicit expiration will always expire at that instant and until then /// is subject to regular (re-)replication and (re-)publication. - pub fn put_record(&mut self, mut record: Record, quorum: Quorum) -> Result { + pub fn put_record( + &mut self, + mut record: Record, + quorum: Quorum, + ) -> Result { record.publisher = Some(*self.kbuckets.local_key().preimage()); self.store.put(record.clone())?; - record.expires = record.expires.or_else(|| - self.record_ttl.map(|ttl| Instant::now() + ttl)); + record.expires = record + .expires + .or_else(|| self.record_ttl.map(|ttl| Instant::now() + ttl)); let quorum = quorum.eval(self.queries.config().replication_factor); let target = kbucket::Key::new(record.key.clone()); let peers = self.kbuckets.closest_keys(&target); @@ -682,7 +699,7 @@ where context, record, quorum, - phase: PutRecordPhase::GetClosestPeers + phase: PutRecordPhase::GetClosestPeers, }; let inner = QueryInner::new(info); Ok(self.queries.add_iter_closest(target.clone(), peers, inner)) @@ -710,7 +727,7 @@ where /// > caching or for other reasons. pub fn put_record_to(&mut self, mut record: Record, peers: I, quorum: Quorum) -> QueryId where - I: ExactSizeIterator + I: ExactSizeIterator, { let quorum = if peers.len() > 0 { quorum.eval(NonZeroUsize::new(peers.len()).expect("> 0")) @@ -720,8 +737,9 @@ where // introducing a new kind of error. NonZeroUsize::new(1).expect("1 > 0") }; - record.expires = record.expires.or_else(|| - self.record_ttl.map(|ttl| Instant::now() + ttl)); + record.expires = record + .expires + .or_else(|| self.record_ttl.map(|ttl| Instant::now() + ttl)); let context = PutRecordContext::Custom; let info = QueryInfo::PutRecord { context, @@ -729,8 +747,8 @@ where quorum, phase: PutRecordPhase::PutRecord { success: Vec::new(), - get_closest_peers_stats: QueryStats::empty() - } + get_closest_peers_stats: QueryStats::empty(), + }, }; let inner = QueryInner::new(info); self.queries.add_fixed(peers, inner) @@ -781,7 +799,7 @@ where let local_key = self.kbuckets.local_key().clone(); let info = QueryInfo::Bootstrap { peer: *local_key.preimage(), - remaining: None + remaining: None, }; let peers = self.kbuckets.closest_keys(&local_key).collect::>(); if peers.is_empty() { @@ -822,7 +840,8 @@ where let record = ProviderRecord::new( key.clone(), *self.kbuckets.local_key().preimage(), - local_addrs); + local_addrs, + ); self.store.add_provider(record)?; let target = kbucket::Key::new(key.clone()); let peers = self.kbuckets.closest_keys(&target); @@ -830,7 +849,7 @@ where let info = QueryInfo::AddProvider { context, key, - phase: AddProviderPhase::GetClosestPeers + phase: AddProviderPhase::GetClosestPeers, }; let inner = QueryInner::new(info); let id = self.queries.add_iter_closest(target.clone(), peers, inner); @@ -842,7 +861,8 @@ where /// This is a local operation. The local node will still be considered as a /// provider for the key by other nodes until these provider records expire. pub fn stop_providing(&mut self, key: &record::Key) { - self.store.remove_provider(key, self.kbuckets.local_key().preimage()); + self.store + .remove_provider(key, self.kbuckets.local_key().preimage()); } /// Performs a lookup for providers of a value to the given key. @@ -863,15 +883,19 @@ where /// Processes discovered peers from a successful request in an iterative `Query`. fn discovered<'a, I>(&'a mut self, query_id: &QueryId, source: &PeerId, peers: I) where - I: Iterator + Clone + I: Iterator + Clone, { let local_id = self.kbuckets.local_key().preimage(); let others_iter = peers.filter(|p| &p.node_id != local_id); if let Some(query) = self.queries.get_mut(query_id) { log::trace!("Request to {:?} in query {:?} succeeded.", source, query_id); for peer in others_iter.clone() { - log::trace!("Peer {:?} reported by {:?} in query {:?}.", - peer, source, query_id); + log::trace!( + "Peer {:?} reported by {:?} in query {:?}.", + peer, + source, + query_id + ); let addrs = peer.multiaddrs.iter().cloned().collect(); query.inner.addresses.insert(peer.node_id, addrs); } @@ -882,7 +906,11 @@ where /// Finds the closest peers to a `target` in the context of a request by /// the `source` peer, such that the `source` peer is never included in the /// result. - fn find_closest(&mut self, target: &kbucket::Key, source: &PeerId) -> Vec { + fn find_closest( + &mut self, + target: &kbucket::Key, + source: &PeerId, + ) -> Vec { if target == self.kbuckets.local_key() { Vec::new() } else { @@ -900,9 +928,10 @@ where let kbuckets = &mut self.kbuckets; let connected = &mut self.connected_peers; let local_addrs = &self.local_addrs; - self.store.providers(key) + self.store + .providers(key) .into_iter() - .filter_map(move |p| + .filter_map(move |p| { if &p.provider != source { let node_id = p.provider; let multiaddrs = p.addresses; @@ -922,21 +951,23 @@ where Some(local_addrs.iter().cloned().collect::>()) } else { let key = kbucket::Key::from(node_id); - kbuckets.entry(&key).view().map(|e| e.node.value.clone().into_vec()) + kbuckets + .entry(&key) + .view() + .map(|e| e.node.value.clone().into_vec()) } } else { Some(multiaddrs) } - .map(|multiaddrs| { - KadPeer { - node_id, - multiaddrs, - connection_ty, - } + .map(|multiaddrs| KadPeer { + node_id, + multiaddrs, + connection_ty, }) } else { None - }) + } + }) .take(self.queries.config().replication_factor.get()) .collect() } @@ -946,7 +977,7 @@ where let info = QueryInfo::AddProvider { context, key: key.clone(), - phase: AddProviderPhase::GetClosestPeers + phase: AddProviderPhase::GetClosestPeers, }; let target = kbucket::Key::new(key); let peers = self.kbuckets.closest_keys(&target); @@ -960,14 +991,22 @@ where let target = kbucket::Key::new(record.key.clone()); let peers = self.kbuckets.closest_keys(&target); let info = QueryInfo::PutRecord { - record, quorum, context, phase: PutRecordPhase::GetClosestPeers + record, + quorum, + context, + phase: PutRecordPhase::GetClosestPeers, }; let inner = QueryInner::new(info); self.queries.add_iter_closest(target.clone(), peers, inner); } /// Updates the routing table with a new connection status and address of a peer. - fn connection_updated(&mut self, peer: PeerId, address: Option, new_status: NodeStatus) { + fn connection_updated( + &mut self, + peer: PeerId, + address: Option, + new_status: NodeStatus, + ) { let key = kbucket::Key::from(peer); match self.kbuckets.entry(&key) { kbucket::Entry::Present(mut entry, old_status) => { @@ -976,21 +1015,23 @@ where } if let Some(address) = address { if entry.value().insert(address) { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutingUpdated { - peer, - is_new_peer: false, - addresses: entry.value().clone(), - old_peer: None, - bucket_range: self.kbuckets - .bucket(&key) - .map(|b| b.range()) - .expect("Not kbucket::Entry::SelfEntry."), - } - )) + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutingUpdated { + peer, + is_new_peer: false, + addresses: entry.value().clone(), + old_peer: None, + bucket_range: self + .kbuckets + .bucket(&key) + .map(|b| b.range()) + .expect("Not kbucket::Entry::SelfEntry."), + }, + )) } } - }, + } kbucket::Entry::Pending(mut entry, old_status) => { if let Some(address) = address { @@ -999,23 +1040,25 @@ where if old_status != new_status { entry.update(new_status); } - }, + } kbucket::Entry::Absent(entry) => { // Only connected nodes with a known address are newly inserted. if new_status != NodeStatus::Connected { - return + return; } match (address, self.kbucket_inserts) { (None, _) => { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::UnroutablePeer { peer } - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::UnroutablePeer { peer }, + )); } (Some(a), KademliaBucketInserts::Manual) => { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutablePeer { peer, address: a } - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutablePeer { peer, address: a }, + )); } (Some(a), KademliaBucketInserts::OnConnected) => { let addresses = Addresses::new(a); @@ -1026,26 +1069,31 @@ where is_new_peer: true, addresses, old_peer: None, - bucket_range: self.kbuckets + bucket_range: self + .kbuckets .bucket(&key) .map(|b| b.range()) .expect("Not kbucket::Entry::SelfEntry."), }; - self.queued_events.push_back( - NetworkBehaviourAction::GenerateEvent(event)); - }, + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent(event)); + } kbucket::InsertResult::Full => { debug!("Bucket full. Peer not added to routing table: {}", peer); let address = addresses.first().clone(); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutablePeer { peer, address } - )); - }, + self.queued_events.push_back( + NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutablePeer { peer, address }, + ), + ); + } kbucket::InsertResult::Pending { disconnected } => { let address = addresses.first().clone(); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::PendingRoutablePeer { peer, address } - )); + self.queued_events.push_back( + NetworkBehaviourAction::GenerateEvent( + KademliaEvent::PendingRoutablePeer { peer, address }, + ), + ); // `disconnected` might already be in the process of re-connecting. // In other words `disconnected` might have already re-connected but @@ -1054,24 +1102,27 @@ where // // Only try dialing peer if not currently connected. if !self.connected_peers.contains(disconnected.preimage()) { - self.queued_events.push_back(NetworkBehaviourAction::DialPeer { - peer_id: disconnected.into_preimage(), - condition: DialPeerCondition::Disconnected - }) + self.queued_events + .push_back(NetworkBehaviourAction::DialPeer { + peer_id: disconnected.into_preimage(), + condition: DialPeerCondition::Disconnected, + }) } - }, + } } } } - }, + } _ => {} } } /// Handles a finished (i.e. successful) query. - fn query_finished(&mut self, q: Query, params: &mut impl PollParameters) - -> Option - { + fn query_finished( + &mut self, + q: Query, + params: &mut impl PollParameters, + ) -> Option { let query_id = q.id(); log::trace!("Query {:?} finished.", query_id); let result = q.into_result(); @@ -1084,7 +1135,8 @@ where // a bucket refresh should be performed for every bucket farther away than // the first non-empty bucket (which are most likely no more than the last // few, i.e. farthest, buckets). - self.kbuckets.iter() + self.kbuckets + .iter() .skip_while(|b| b.is_empty()) .skip(1) // Skip the bucket with the closest neighbour. .map(|b| { @@ -1102,7 +1154,7 @@ where // Pr(bucket-252) = 1 - (15/16)^16 ~= 0.64 // ... let mut target = kbucket::Key::from(PeerId::random()); - for _ in 0 .. 16 { + for _ in 0..16 { let d = local_key.distance(&target); if b.contains(&d) { break; @@ -1110,7 +1162,9 @@ where target = kbucket::Key::from(PeerId::random()); } target - }).collect::>().into_iter() + }) + .collect::>() + .into_iter() }); let num_remaining = remaining.len() as u32; @@ -1118,48 +1172,49 @@ where if let Some(target) = remaining.next() { let info = QueryInfo::Bootstrap { peer: target.clone().into_preimage(), - remaining: Some(remaining) + remaining: Some(remaining), }; let peers = self.kbuckets.closest_keys(&target); let inner = QueryInner::new(info); - self.queries.continue_iter_closest(query_id, target.clone(), peers, inner); + self.queries + .continue_iter_closest(query_id, target.clone(), peers, inner); } Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::Bootstrap(Ok(BootstrapOk { peer, num_remaining })) + result: QueryResult::Bootstrap(Ok(BootstrapOk { + peer, + num_remaining, + })), }) } - QueryInfo::GetClosestPeers { key, .. } => { - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::GetClosestPeers(Ok( - GetClosestPeersOk { key, peers: result.peers.collect() } - )) - }) - } + QueryInfo::GetClosestPeers { key, .. } => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::GetClosestPeers(Ok(GetClosestPeersOk { + key, + peers: result.peers.collect(), + })), + }), QueryInfo::GetProviders { key, providers } => { Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::GetProviders(Ok( - GetProvidersOk { - key, - providers, - closest_peers: result.peers.collect() - } - )) + result: QueryResult::GetProviders(Ok(GetProvidersOk { + key, + providers, + closest_peers: result.peers.collect(), + })), }) } QueryInfo::AddProvider { context, key, - phase: AddProviderPhase::GetClosestPeers + phase: AddProviderPhase::GetClosestPeers, } => { let provider_id = *params.local_peer_id(); let external_addresses = params.external_addresses().map(|r| r.addr).collect(); @@ -1169,8 +1224,8 @@ where phase: AddProviderPhase::AddProvider { provider_id, external_addresses, - get_closest_peers_stats: result.stats - } + get_closest_peers_stats: result.stats, + }, }); self.queries.continue_fixed(query_id, result.peers, inner); None @@ -1179,28 +1234,32 @@ where QueryInfo::AddProvider { context, key, - phase: AddProviderPhase::AddProvider { get_closest_peers_stats, .. } - } => { - match context { - AddProviderContext::Publish => { - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: get_closest_peers_stats.merge(result.stats), - result: QueryResult::StartProviding(Ok(AddProviderOk { key })) - }) - } - AddProviderContext::Republish => { - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: get_closest_peers_stats.merge(result.stats), - result: QueryResult::RepublishProvider(Ok(AddProviderOk { key })) - }) - } - } - } + phase: + AddProviderPhase::AddProvider { + get_closest_peers_stats, + .. + }, + } => match context { + AddProviderContext::Publish => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: get_closest_peers_stats.merge(result.stats), + result: QueryResult::StartProviding(Ok(AddProviderOk { key })), + }), + AddProviderContext::Republish => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: get_closest_peers_stats.merge(result.stats), + result: QueryResult::RepublishProvider(Ok(AddProviderOk { key })), + }), + }, - QueryInfo::GetRecord { key, records, quorum, cache_candidates } => { - let results = if records.len() >= quorum.get() { // [not empty] + QueryInfo::GetRecord { + key, + records, + quorum, + cache_candidates, + } => { + let results = if records.len() >= quorum.get() { + // [not empty] if quorum.get() == 1 && !cache_candidates.is_empty() { // Cache the record at the closest node(s) to the key that // did not return the record. @@ -1213,25 +1272,33 @@ where quorum, phase: PutRecordPhase::PutRecord { success: vec![], - get_closest_peers_stats: QueryStats::empty() - } + get_closest_peers_stats: QueryStats::empty(), + }, }; let inner = QueryInner::new(info); - self.queries.add_fixed(cache_candidates.values().copied(), inner); + self.queries + .add_fixed(cache_candidates.values().copied(), inner); } - Ok(GetRecordOk { records, cache_candidates }) + Ok(GetRecordOk { + records, + cache_candidates, + }) } else if records.is_empty() { Err(GetRecordError::NotFound { key, - closest_peers: result.peers.collect() + closest_peers: result.peers.collect(), }) } else { - Err(GetRecordError::QuorumFailed { key, records, quorum }) + Err(GetRecordError::QuorumFailed { + key, + records, + quorum, + }) }; Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::GetRecord(results) + result: QueryResult::GetRecord(results), }) } @@ -1239,7 +1306,7 @@ where context, record, quorum, - phase: PutRecordPhase::GetClosestPeers + phase: PutRecordPhase::GetClosestPeers, } => { let info = QueryInfo::PutRecord { context, @@ -1247,8 +1314,8 @@ where quorum, phase: PutRecordPhase::PutRecord { success: vec![], - get_closest_peers_stats: result.stats - } + get_closest_peers_stats: result.stats, + }, }; let inner = QueryInner::new(info); self.queries.continue_fixed(query_id, result.peers, inner); @@ -1259,28 +1326,36 @@ where context, record, quorum, - phase: PutRecordPhase::PutRecord { success, get_closest_peers_stats } + phase: + PutRecordPhase::PutRecord { + success, + get_closest_peers_stats, + }, } => { let mk_result = |key: record::Key| { if success.len() >= quorum.get() { Ok(PutRecordOk { key }) } else { - Err(PutRecordError::QuorumFailed { key, quorum, success }) + Err(PutRecordError::QuorumFailed { + key, + quorum, + success, + }) } }; match context { - PutRecordContext::Publish | PutRecordContext::Custom => - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: get_closest_peers_stats.merge(result.stats), - result: QueryResult::PutRecord(mk_result(record.key)) - }), - PutRecordContext::Republish => + PutRecordContext::Publish | PutRecordContext::Custom => { Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: get_closest_peers_stats.merge(result.stats), - result: QueryResult::RepublishRecord(mk_result(record.key)) - }), + result: QueryResult::PutRecord(mk_result(record.key)), + }) + } + PutRecordContext::Republish => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: get_closest_peers_stats.merge(result.stats), + result: QueryResult::RepublishRecord(mk_result(record.key)), + }), PutRecordContext::Replicate => { debug!("Record replicated: {:?}", record.key); None @@ -1300,7 +1375,10 @@ where log::trace!("Query {:?} timed out.", query_id); let result = query.into_result(); match result.inner.info { - QueryInfo::Bootstrap { peer, mut remaining } => { + QueryInfo::Bootstrap { + peer, + mut remaining, + } => { let num_remaining = remaining.as_ref().map(|r| r.len().saturating_sub(1) as u32); if let Some(mut remaining) = remaining.take() { @@ -1308,78 +1386,74 @@ where if let Some(target) = remaining.next() { let info = QueryInfo::Bootstrap { peer: target.clone().into_preimage(), - remaining: Some(remaining) + remaining: Some(remaining), }; let peers = self.kbuckets.closest_keys(&target); let inner = QueryInner::new(info); - self.queries.continue_iter_closest(query_id, target.clone(), peers, inner); + self.queries + .continue_iter_closest(query_id, target.clone(), peers, inner); } } Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::Bootstrap(Err( - BootstrapError::Timeout { peer, num_remaining } - )) + result: QueryResult::Bootstrap(Err(BootstrapError::Timeout { + peer, + num_remaining, + })), }) } - QueryInfo::AddProvider { context, key, .. } => - Some(match context { - AddProviderContext::Publish => - KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::StartProviding(Err( - AddProviderError::Timeout { key } - )) - }, - AddProviderContext::Republish => - KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::RepublishProvider(Err( - AddProviderError::Timeout { key } - )) - } - }), - - QueryInfo::GetClosestPeers { key } => { - Some(KademliaEvent::OutboundQueryCompleted { + QueryInfo::AddProvider { context, key, .. } => Some(match context { + AddProviderContext::Publish => KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::GetClosestPeers(Err( - GetClosestPeersError::Timeout { - key, - peers: result.peers.collect() - } - )) - }) - }, + result: QueryResult::StartProviding(Err(AddProviderError::Timeout { key })), + }, + AddProviderContext::Republish => KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::RepublishProvider(Err(AddProviderError::Timeout { key })), + }, + }), + + QueryInfo::GetClosestPeers { key } => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::GetClosestPeers(Err(GetClosestPeersError::Timeout { + key, + peers: result.peers.collect(), + })), + }), - QueryInfo::PutRecord { record, quorum, context, phase } => { + QueryInfo::PutRecord { + record, + quorum, + context, + phase, + } => { let err = Err(PutRecordError::Timeout { key: record.key, quorum, success: match phase { PutRecordPhase::GetClosestPeers => vec![], PutRecordPhase::PutRecord { ref success, .. } => success.clone(), - } + }, }); match context { - PutRecordContext::Publish | PutRecordContext::Custom => + PutRecordContext::Publish | PutRecordContext::Custom => { Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::PutRecord(err) - }), - PutRecordContext::Republish => - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::RepublishRecord(err) - }), + result: QueryResult::PutRecord(err), + }) + } + PutRecordContext::Republish => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::RepublishRecord(err), + }), PutRecordContext::Replicate => match phase { PutRecordPhase::GetClosestPeers => { warn!("Locating closest peers for replication failed: {:?}", err); @@ -1389,7 +1463,7 @@ where debug!("Replicating record failed: {:?}", err); None } - } + }, PutRecordContext::Cache => match phase { PutRecordPhase::GetClosestPeers => { // Caching a record at the closest peer to a key that did not return @@ -1401,32 +1475,37 @@ where debug!("Caching record failed: {:?}", err); None } - } + }, } } - QueryInfo::GetRecord { key, records, quorum, .. } => - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::GetRecord(Err( - GetRecordError::Timeout { key, records, quorum }, - )) - }), + QueryInfo::GetRecord { + key, + records, + quorum, + .. + } => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::GetRecord(Err(GetRecordError::Timeout { + key, + records, + quorum, + })), + }), - QueryInfo::GetProviders { key, providers } => + QueryInfo::GetProviders { key, providers } => { Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::GetProviders(Err( - GetProvidersError::Timeout { - key, - providers, - closest_peers: result.peers.collect() - } - )) + result: QueryResult::GetProviders(Err(GetProvidersError::Timeout { + key, + providers, + closest_peers: result.peers.collect(), + })), }) } + } } /// Processes a record received from a peer. @@ -1435,22 +1514,23 @@ where source: PeerId, connection: ConnectionId, request_id: KademliaRequestId, - mut record: Record + mut record: Record, ) { if record.publisher.as_ref() == Some(self.kbuckets.local_key().preimage()) { // If the (alleged) publisher is the local node, do nothing. The record of // the original publisher should never change as a result of replication // and the publisher is always assumed to have the "right" value. - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::PutRecordRes { - key: record.key, - value: record.value, - request_id, - }, - }); - return + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::PutRecordRes { + key: record.key, + value: record.value, + request_id, + }, + }); + return; } let now = Instant::now(); @@ -1463,7 +1543,9 @@ where let num_between = self.kbuckets.count_nodes_between(&target); let k = self.queries.config().replication_factor.get(); let num_beyond_k = (usize::max(k, num_between) - k) as u32; - let expiration = self.record_ttl.map(|ttl| now + exp_decrease(ttl, num_beyond_k)); + let expiration = self + .record_ttl + .map(|ttl| now + exp_decrease(ttl, num_beyond_k)); // The smaller TTL prevails. Only if neither TTL is set is the record // stored "forever". record.expires = record.expires.or(expiration).min(expiration); @@ -1491,16 +1573,21 @@ where // requirement to send back the value in the response, although this // is a waste of resources. match self.store.put(record.clone()) { - Ok(()) => debug!("Record stored: {:?}; {} bytes", record.key, record.value.len()), + Ok(()) => debug!( + "Record stored: {:?}; {} bytes", + record.key, + record.value.len() + ), Err(e) => { info!("Record not stored: {:?}", e); - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::Reset(request_id) - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::Reset(request_id), + }); - return + return; } } } @@ -1512,15 +1599,16 @@ where // closest nodes to the target. In addition returning // [`KademliaHandlerIn::PutRecordRes`] does not reveal any internal // information to a possibly malicious remote node. - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::PutRecordRes { - key: record.key, - value: record.value, - request_id, - }, - }) + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::PutRecordRes { + key: record.key, + value: record.value, + request_id, + }, + }) } /// Processes a provider record received from a peer. @@ -1593,14 +1681,19 @@ where fn inject_connected(&mut self, peer: &PeerId) { // Queue events for sending pending RPCs to the connected peer. // There can be only one pending RPC for a particular peer and query per definition. - for (peer_id, event) in self.queries.iter_mut().filter_map(|q| - q.inner.pending_rpcs.iter() + for (peer_id, event) in self.queries.iter_mut().filter_map(|q| { + q.inner + .pending_rpcs + .iter() .position(|(p, _)| p == peer) - .map(|p| q.inner.pending_rpcs.remove(p))) - { - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id, event, handler: NotifyHandler::Any - }); + .map(|p| q.inner.pending_rpcs.remove(p)) + }) { + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id, + event, + handler: NotifyHandler::Any, + }); } self.connected_peers.insert(*peer); @@ -1611,14 +1704,17 @@ where peer: &PeerId, _: &ConnectionId, old: &ConnectedPoint, - new: &ConnectedPoint + new: &ConnectedPoint, ) { let (old, new) = (old.get_remote_address(), new.get_remote_address()); // Update routing table. if let Some(addrs) = self.kbuckets.entry(&kbucket::Key::from(*peer)).value() { if addrs.replace(old, new) { - debug!("Address '{}' replaced with '{}' for peer '{}'.", old, new, peer); + debug!( + "Address '{}' replaced with '{}' for peer '{}'.", + old, new, peer + ); } else { debug!( "Address '{}' not replaced with '{}' for peer '{}' as old address wasn't \ @@ -1663,7 +1759,7 @@ where &mut self, peer_id: Option<&PeerId>, addr: &Multiaddr, - err: &dyn error::Error + err: &dyn error::Error, ) { if let Some(peer_id) = peer_id { let key = kbucket::Key::from(*peer_id); @@ -1675,8 +1771,10 @@ where // of the error is not possible (and also not truly desirable or ergonomic). // The error passed in should rather be a dedicated enum. if addrs.remove(addr).is_ok() { - debug!("Address '{}' removed from peer '{}' due to error: {}.", - addr, peer_id, err); + debug!( + "Address '{}' removed from peer '{}' due to error: {}.", + addr, peer_id, err + ); } else { // Despite apparently having no reachable address (any longer), // the peer is kept in the routing table with the last address to avoid @@ -1687,8 +1785,10 @@ where // into the same bucket. This is handled transparently by the // `KBucketsTable` and takes effect through `KBucketsTable::take_applied_pending` // within `Kademlia::poll`. - debug!("Last remaining address '{}' of peer '{}' is unreachable: {}.", - addr, peer_id, err) + debug!( + "Last remaining address '{}' of peer '{}' is unreachable: {}.", + addr, peer_id, err + ) } } @@ -1718,7 +1818,7 @@ where &mut self, source: PeerId, connection: ConnectionId, - event: KademliaHandlerEvent + event: KademliaHandlerEvent, ) { match event { KademliaHandlerEvent::ProtocolConfirmed { endpoint } => { @@ -1737,20 +1837,24 @@ where KademliaHandlerEvent::FindNodeReq { key, request_id } => { let closer_peers = self.find_closest(&kbucket::Key::new(key), &source); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::FindNode { - num_closer_peers: closer_peers.len(), - }} - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::FindNode { + num_closer_peers: closer_peers.len(), + }, + }, + )); - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::FindNodeRes { - closer_peers, - request_id, - }, - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::FindNodeRes { + closer_peers, + request_id, + }, + }); } KademliaHandlerEvent::FindNodeRes { @@ -1764,22 +1868,26 @@ where let provider_peers = self.provider_peers(&key, &source); let closer_peers = self.find_closest(&kbucket::Key::new(key), &source); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::GetProvider { - num_closer_peers: closer_peers.len(), - num_provider_peers: provider_peers.len(), - }} - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::GetProvider { + num_closer_peers: closer_peers.len(), + num_provider_peers: provider_peers.len(), + }, + }, + )); - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::GetProvidersRes { - closer_peers, - provider_peers, - request_id, - }, - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::GetProvidersRes { + closer_peers, + provider_peers, + request_id, + }, + }); } KademliaHandlerEvent::GetProvidersRes { @@ -1790,9 +1898,7 @@ where let peers = closer_peers.iter().chain(provider_peers.iter()); self.discovered(&user_data, &source, peers); if let Some(query) = self.queries.get_mut(&user_data) { - if let QueryInfo::GetProviders { - providers, .. - } = &mut query.inner.info { + if let QueryInfo::GetProviders { providers, .. } = &mut query.inner.info { for peer in provider_peers { providers.insert(peer.node_id); } @@ -1801,8 +1907,12 @@ where } KademliaHandlerEvent::QueryError { user_data, error } => { - log::debug!("Request to {:?} in query {:?} failed with {:?}", - source, user_data, error); + log::debug!( + "Request to {:?} in query {:?} failed with {:?}", + source, + user_data, + error + ); // If the query to which the error relates is still active, // signal the failure w.r.t. `source`. if let Some(query) = self.queries.get_mut(&user_data) { @@ -1813,14 +1923,17 @@ where KademliaHandlerEvent::AddProvider { key, provider } => { // Only accept a provider record from a legitimate peer. if provider.node_id != source { - return + return; } self.provider_received(key, provider); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::AddProvider {} } - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::AddProvider {}, + }, + )); } KademliaHandlerEvent::GetRecord { key, request_id } => { @@ -1833,28 +1946,32 @@ where } else { Some(record.into_owned()) } - }, - None => None + } + None => None, }; let closer_peers = self.find_closest(&kbucket::Key::new(key), &source); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::GetRecord { - num_closer_peers: closer_peers.len(), - present_locally: record.is_some(), - }} - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::GetRecord { + num_closer_peers: closer_peers.len(), + present_locally: record.is_some(), + }, + }, + )); - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::GetRecordRes { - record, - closer_peers, - request_id, - }, - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::GetRecordRes { + record, + closer_peers, + request_id, + }, + }); } KademliaHandlerEvent::GetRecordRes { @@ -1864,17 +1981,25 @@ where } => { if let Some(query) = self.queries.get_mut(&user_data) { if let QueryInfo::GetRecord { - key, records, quorum, cache_candidates - } = &mut query.inner.info { + key, + records, + quorum, + cache_candidates, + } = &mut query.inner.info + { if let Some(record) = record { - records.push(PeerRecord{ peer: Some(source), record }); + records.push(PeerRecord { + peer: Some(source), + record, + }); let quorum = quorum.get(); if records.len() >= quorum { // Desired quorum reached. The query may finish. See // [`Query::try_finish`] for details. - let peers = records.iter() - .filter_map(|PeerRecord{ peer, .. }| peer.as_ref()) + let peers = records + .iter() + .filter_map(|PeerRecord { peer, .. }| peer.as_ref()) .cloned() .collect::>(); let finished = query.try_finish(peers.iter()); @@ -1882,7 +2007,10 @@ where debug!( "GetRecord query ({:?}) reached quorum ({}/{}) with \ response from peer {} but could not yet finish.", - user_data, peers.len(), quorum, source, + user_data, + peers.len(), + quorum, + source, ); } } @@ -1896,7 +2024,8 @@ where if cache_candidates.len() > max_peers as usize { // TODO: `pop_last()` would be nice once stabilised. // See https://github.com/rust-lang/rust/issues/62924. - let last = *cache_candidates.keys().next_back().expect("len > 0"); + let last = + *cache_candidates.keys().next_back().expect("len > 0"); cache_candidates.remove(&last); } } @@ -1907,25 +2036,26 @@ where self.discovered(&user_data, &source, closer_peers.iter()); } - KademliaHandlerEvent::PutRecord { - record, - request_id - } => { + KademliaHandlerEvent::PutRecord { record, request_id } => { self.record_received(source, connection, request_id, record); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::PutRecord {} } - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::PutRecord {}, + }, + )); } - KademliaHandlerEvent::PutRecordRes { - user_data, .. - } => { + KademliaHandlerEvent::PutRecordRes { user_data, .. } => { if let Some(query) = self.queries.get_mut(&user_data) { query.on_success(&source, vec![]); if let QueryInfo::PutRecord { - phase: PutRecordPhase::PutRecord { success, .. }, quorum, .. - } = &mut query.inner.info { + phase: PutRecordPhase::PutRecord { success, .. }, + quorum, + .. + } = &mut query.inner.info + { success.push(source); let quorum = quorum.get(); @@ -1936,7 +2066,10 @@ where debug!( "PutRecord query ({:?}) reached quorum ({}/{}) with response \ from peer {} but could not yet finish.", - user_data, peers.len(), quorum, source, + user_data, + peers.len(), + quorum, + source, ); } } @@ -1960,12 +2093,11 @@ where } } - fn poll(&mut self, cx: &mut Context<'_>, parameters: &mut impl PollParameters) -> Poll< - NetworkBehaviourAction< - KademliaHandlerIn, - Self::OutEvent, - >, - > { + fn poll( + &mut self, + cx: &mut Context<'_>, + parameters: &mut impl PollParameters, + ) -> Poll, Self::OutEvent>> { let now = Instant::now(); // Calculate the available capacity for queries triggered by background jobs. @@ -1974,11 +2106,11 @@ where // Run the periodic provider announcement job. if let Some(mut job) = self.add_provider_job.take() { let num = usize::min(JOBS_MAX_NEW_QUERIES, jobs_query_capacity); - for _ in 0 .. num { + for _ in 0..num { if let Poll::Ready(r) = job.poll(cx, &mut self.store, now) { self.start_add_provider(r.key, AddProviderContext::Republish) } else { - break + break; } } jobs_query_capacity -= num; @@ -1988,16 +2120,17 @@ where // Run the periodic record replication / publication job. if let Some(mut job) = self.put_record_job.take() { let num = usize::min(JOBS_MAX_NEW_QUERIES, jobs_query_capacity); - for _ in 0 .. num { + for _ in 0..num { if let Poll::Ready(r) = job.poll(cx, &mut self.store, now) { - let context = if r.publisher.as_ref() == Some(self.kbuckets.local_key().preimage()) { - PutRecordContext::Republish - } else { - PutRecordContext::Replicate - }; + let context = + if r.publisher.as_ref() == Some(self.kbuckets.local_key().preimage()) { + PutRecordContext::Republish + } else { + PutRecordContext::Replicate + }; self.start_put_record(r, Quorum::All, context) } else { - break + break; } } self.put_record_job = Some(job); @@ -2013,7 +2146,8 @@ where if let Some(entry) = self.kbuckets.take_applied_pending() { let kbucket::Node { key, value } = entry.inserted; let event = KademliaEvent::RoutingUpdated { - bucket_range: self.kbuckets + bucket_range: self + .kbuckets .bucket(&key) .map(|b| b.range()) .expect("Self to never be applied from pending."), @@ -2022,7 +2156,7 @@ where addresses: value, old_peer: entry.evicted.map(|n| n.key.into_preimage()), }; - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); } // Look for a finished query. @@ -2030,12 +2164,12 @@ where match self.queries.poll(now) { QueryPoolState::Finished(q) => { if let Some(event) = self.query_finished(q, parameters) { - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); } } QueryPoolState::Timeout(q) => { if let Some(event) = self.query_timeout(q) { - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); } } QueryPoolState::Waiting(Some((query, peer_id))) => { @@ -2048,18 +2182,24 @@ where if let QueryInfo::AddProvider { phase: AddProviderPhase::AddProvider { .. }, .. - } = &query.inner.info { + } = &query.inner.info + { query.on_success(&peer_id, vec![]) } if self.connected_peers.contains(&peer_id) { - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id, event, handler: NotifyHandler::Any - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id, + event, + handler: NotifyHandler::Any, + }); } else if &peer_id != self.kbuckets.local_key().preimage() { query.inner.pending_rpcs.push((peer_id, event)); - self.queued_events.push_back(NetworkBehaviourAction::DialPeer { - peer_id, condition: DialPeerCondition::Disconnected - }); + self.queued_events + .push_back(NetworkBehaviourAction::DialPeer { + peer_id, + condition: DialPeerCondition::Disconnected, + }); } } QueryPoolState::Waiting(None) | QueryPoolState::Idle => break, @@ -2070,7 +2210,7 @@ where // If no new events have been queued either, signal `NotReady` to // be polled again later. if self.queued_events.is_empty() { - return Poll::Pending + return Poll::Pending; } } } @@ -2084,7 +2224,7 @@ pub enum Quorum { One, Majority, All, - N(NonZeroUsize) + N(NonZeroUsize), } impl Quorum { @@ -2094,7 +2234,7 @@ impl Quorum { Quorum::One => NonZeroUsize::new(1).expect("1 != 0"), Quorum::Majority => NonZeroUsize::new(total.get() / 2 + 1).expect("n + 1 != 0"), Quorum::All => total, - Quorum::N(n) => NonZeroUsize::min(total, *n) + Quorum::N(n) => NonZeroUsize::min(total, *n), } } } @@ -2122,9 +2262,7 @@ pub enum KademliaEvent { // Note on the difference between 'request' and 'query': A request is a // single request-response style exchange with a single remote peer. A query // is made of multiple requests across multiple remote peers. - InboundRequestServed { - request: InboundRequest, - }, + InboundRequestServed { request: InboundRequest }, /// An outbound query has produced a result. OutboundQueryCompleted { @@ -2133,7 +2271,7 @@ pub enum KademliaEvent { /// The result of the query. result: QueryResult, /// Execution statistics from the query. - stats: QueryStats + stats: QueryStats, }, /// The routing table has been updated with a new peer and / or @@ -2158,9 +2296,7 @@ pub enum KademliaEvent { /// /// If the peer is to be added to the routing table, a known /// listen address for the peer must be provided via [`Kademlia::add_address`]. - UnroutablePeer { - peer: PeerId - }, + UnroutablePeer { peer: PeerId }, /// A connection to a peer has been established for whom a listen address /// is known but the peer has not been added to the routing table either @@ -2173,10 +2309,7 @@ pub enum KademliaEvent { /// /// See [`Kademlia::kbucket`] for insight into the contents of /// the k-bucket of `peer`. - RoutablePeer { - peer: PeerId, - address: Multiaddr, - }, + RoutablePeer { peer: PeerId, address: Multiaddr }, /// A connection to a peer has been established for whom a listen address /// is known but the peer is only pending insertion into the routing table @@ -2189,19 +2322,14 @@ pub enum KademliaEvent { /// /// See [`Kademlia::kbucket`] for insight into the contents of /// the k-bucket of `peer`. - PendingRoutablePeer { - peer: PeerId, - address: Multiaddr, - } + PendingRoutablePeer { peer: PeerId, address: Multiaddr }, } /// Information about a received and handled inbound request. #[derive(Debug)] pub enum InboundRequest { /// Request for the list of nodes whose IDs are the closest to `key`. - FindNode { - num_closer_peers: usize, - }, + FindNode { num_closer_peers: usize }, /// Same as `FindNode`, but should also return the entries of the local /// providers list for this key. GetProvider { @@ -2278,18 +2406,18 @@ pub struct GetRecordOk { pub enum GetRecordError { NotFound { key: record::Key, - closest_peers: Vec + closest_peers: Vec, }, QuorumFailed { key: record::Key, records: Vec, - quorum: NonZeroUsize + quorum: NonZeroUsize, }, Timeout { key: record::Key, records: Vec, - quorum: NonZeroUsize - } + quorum: NonZeroUsize, + }, } impl GetRecordError { @@ -2319,7 +2447,7 @@ pub type PutRecordResult = Result; /// The successful result of [`Kademlia::put_record`]. #[derive(Debug, Clone)] pub struct PutRecordOk { - pub key: record::Key + pub key: record::Key, } /// The error result of [`Kademlia::put_record`]. @@ -2329,13 +2457,13 @@ pub enum PutRecordError { key: record::Key, /// [`PeerId`]s of the peers the record was successfully stored on. success: Vec, - quorum: NonZeroUsize + quorum: NonZeroUsize, }, Timeout { key: record::Key, /// [`PeerId`]s of the peers the record was successfully stored on. success: Vec, - quorum: NonZeroUsize + quorum: NonZeroUsize, }, } @@ -2374,7 +2502,7 @@ pub enum BootstrapError { Timeout { peer: PeerId, num_remaining: Option, - } + }, } /// The result of [`Kademlia::get_closest_peers`]. @@ -2384,16 +2512,13 @@ pub type GetClosestPeersResult = Result #[derive(Debug, Clone)] pub struct GetClosestPeersOk { pub key: Vec, - pub peers: Vec + pub peers: Vec, } /// The error result of [`Kademlia::get_closest_peers`]. #[derive(Debug, Clone)] pub enum GetClosestPeersError { - Timeout { - key: Vec, - peers: Vec - } + Timeout { key: Vec, peers: Vec }, } impl GetClosestPeersError { @@ -2421,7 +2546,7 @@ pub type GetProvidersResult = Result; pub struct GetProvidersOk { pub key: record::Key, pub providers: HashSet, - pub closest_peers: Vec + pub closest_peers: Vec, } /// The error result of [`Kademlia::get_providers`]. @@ -2430,8 +2555,8 @@ pub enum GetProvidersError { Timeout { key: record::Key, providers: HashSet, - closest_peers: Vec - } + closest_peers: Vec, + }, } impl GetProvidersError { @@ -2464,9 +2589,7 @@ pub struct AddProviderOk { #[derive(Debug)] pub enum AddProviderError { /// The query timed out. - Timeout { - key: record::Key, - }, + Timeout { key: record::Key }, } impl AddProviderError { @@ -2492,8 +2615,8 @@ impl From, Addresses>> for KadPeer { multiaddrs: e.node.value.into_vec(), connection_ty: match e.status { NodeStatus::Connected => KadConnectionType::Connected, - NodeStatus::Disconnected => KadConnectionType::NotConnected - } + NodeStatus::Disconnected => KadConnectionType::NotConnected, + }, } } } @@ -2510,7 +2633,7 @@ struct QueryInner { /// /// A request is pending if the targeted peer is not currently connected /// and these requests are sent as soon as a connection to the peer is established. - pending_rpcs: SmallVec<[(PeerId, KademliaHandlerIn); K_VALUE.get()]> + pending_rpcs: SmallVec<[(PeerId, KademliaHandlerIn); K_VALUE.get()]>, } impl QueryInner { @@ -2518,7 +2641,7 @@ impl QueryInner { QueryInner { info, addresses: Default::default(), - pending_rpcs: SmallVec::default() + pending_rpcs: SmallVec::default(), } } } @@ -2567,7 +2690,7 @@ pub enum QueryInfo { /// This is `None` if the initial self-lookup has not /// yet completed and `Some` with an exhausted iterator /// if bootstrapping is complete. - remaining: Option>> + remaining: Option>>, }, /// A query initiated by [`Kademlia::get_closest_peers`]. @@ -2639,16 +2762,18 @@ impl QueryInfo { key: key.to_vec(), user_data: query_id, }, - AddProviderPhase::AddProvider { provider_id, external_addresses, .. } => { - KademliaHandlerIn::AddProvider { - key: key.clone(), - provider: crate::protocol::KadPeer { - node_id: *provider_id, - multiaddrs: external_addresses.clone(), - connection_ty: crate::protocol::KadConnectionType::Connected, - } - } - } + AddProviderPhase::AddProvider { + provider_id, + external_addresses, + .. + } => KademliaHandlerIn::AddProvider { + key: key.clone(), + provider: crate::protocol::KadPeer { + node_id: *provider_id, + multiaddrs: external_addresses.clone(), + connection_ty: crate::protocol::KadConnectionType::Connected, + }, + }, }, QueryInfo::GetRecord { key, .. } => KademliaHandlerIn::GetRecord { key: key.clone(), @@ -2661,9 +2786,9 @@ impl QueryInfo { }, PutRecordPhase::PutRecord { .. } => KademliaHandlerIn::PutRecord { record: record.clone(), - user_data: query_id - } - } + user_data: query_id, + }, + }, } } } diff --git a/protocols/kad/src/behaviour/test.rs b/protocols/kad/src/behaviour/test.rs index 9bd640878d0..110c1e2221e 100644 --- a/protocols/kad/src/behaviour/test.rs +++ b/protocols/kad/src/behaviour/test.rs @@ -22,31 +22,30 @@ use super::*; -use crate::K_VALUE; use crate::kbucket::Distance; -use crate::record::{Key, store::MemoryStore}; -use futures::{ - prelude::*, - executor::block_on, - future::poll_fn, -}; +use crate::record::{store::MemoryStore, Key}; +use crate::K_VALUE; +use futures::{executor::block_on, future::poll_fn, prelude::*}; use futures_timer::Delay; use libp2p_core::{ connection::{ConnectedPoint, ConnectionId}, - PeerId, - Transport, identity, - transport::MemoryTransport, - multiaddr::{Protocol, Multiaddr, multiaddr}, - upgrade, + multiaddr::{multiaddr, Multiaddr, Protocol}, multihash::{Code, Multihash, MultihashDigest}, + transport::MemoryTransport, + upgrade, PeerId, Transport, }; use libp2p_noise as noise; use libp2p_swarm::{Swarm, SwarmEvent}; use libp2p_yamux as yamux; use quickcheck::*; -use rand::{Rng, random, thread_rng, rngs::StdRng, SeedableRng}; -use std::{collections::{HashSet, HashMap}, time::Duration, num::NonZeroUsize, u64}; +use rand::{random, rngs::StdRng, thread_rng, Rng, SeedableRng}; +use std::{ + collections::{HashMap, HashSet}, + num::NonZeroUsize, + time::Duration, + u64, +}; type TestSwarm = Swarm>; @@ -57,7 +56,9 @@ fn build_node() -> (Multiaddr, TestSwarm) { fn build_node_with_config(cfg: KademliaConfig) -> (Multiaddr, TestSwarm) { let local_key = identity::Keypair::generate_ed25519(); let local_public_key = local_key.public(); - let noise_keys = noise::Keypair::::new().into_authentic(&local_key).unwrap(); + let noise_keys = noise::Keypair::::new() + .into_authentic(&local_key) + .unwrap(); let transport = MemoryTransport::default() .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) @@ -83,25 +84,33 @@ fn build_nodes(num: usize) -> Vec<(Multiaddr, TestSwarm)> { /// Builds swarms, each listening on a port. Does *not* connect the nodes together. fn build_nodes_with_config(num: usize, cfg: KademliaConfig) -> Vec<(Multiaddr, TestSwarm)> { - (0..num).map(|_| build_node_with_config(cfg.clone())).collect() + (0..num) + .map(|_| build_node_with_config(cfg.clone())) + .collect() } fn build_connected_nodes(total: usize, step: usize) -> Vec<(Multiaddr, TestSwarm)> { build_connected_nodes_with_config(total, step, Default::default()) } -fn build_connected_nodes_with_config(total: usize, step: usize, cfg: KademliaConfig) - -> Vec<(Multiaddr, TestSwarm)> -{ +fn build_connected_nodes_with_config( + total: usize, + step: usize, + cfg: KademliaConfig, +) -> Vec<(Multiaddr, TestSwarm)> { let mut swarms = build_nodes_with_config(total, cfg); - let swarm_ids: Vec<_> = swarms.iter() + let swarm_ids: Vec<_> = swarms + .iter() .map(|(addr, swarm)| (addr.clone(), *swarm.local_peer_id())) .collect(); let mut i = 0; for (j, (addr, peer_id)) in swarm_ids.iter().enumerate().skip(1) { if i < swarm_ids.len() { - swarms[i].1.behaviour_mut().add_address(peer_id, addr.clone()); + swarms[i] + .1 + .behaviour_mut() + .add_address(peer_id, addr.clone()); } if j % step == 0 { i += step; @@ -111,11 +120,13 @@ fn build_connected_nodes_with_config(total: usize, step: usize, cfg: KademliaCon swarms } -fn build_fully_connected_nodes_with_config(total: usize, cfg: KademliaConfig) - -> Vec<(Multiaddr, TestSwarm)> -{ +fn build_fully_connected_nodes_with_config( + total: usize, + cfg: KademliaConfig, +) -> Vec<(Multiaddr, TestSwarm)> { let mut swarms = build_nodes_with_config(total, cfg); - let swarm_addr_and_peer_id: Vec<_> = swarms.iter() + let swarm_addr_and_peer_id: Vec<_> = swarms + .iter() .map(|(addr, swarm)| (addr.clone(), *swarm.local_peer_id())) .collect(); @@ -160,18 +171,12 @@ fn bootstrap() { cfg.disjoint_query_paths(true); } - let mut swarms = build_connected_nodes_with_config( - num_total, - num_group, - cfg, - ).into_iter() + let mut swarms = build_connected_nodes_with_config(num_total, num_group, cfg) + .into_iter() .map(|(_a, s)| s) .collect::>(); - let swarm_ids: Vec<_> = swarms.iter() - .map(Swarm::local_peer_id) - .cloned() - .collect(); + let swarm_ids: Vec<_> = swarms.iter().map(Swarm::local_peer_id).cloned().collect(); let qid = swarms[0].behaviour_mut().bootstrap().unwrap(); @@ -180,46 +185,49 @@ fn bootstrap() { let mut first = true; // Run test - block_on( - poll_fn(move |ctx| { - for (i, swarm) in swarms.iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::Bootstrap(Ok(ok)), .. - }))) => { - assert_eq!(id, qid); - assert_eq!(i, 0); - if first { - // Bootstrapping must start with a self-lookup. - assert_eq!(ok.peer, swarm_ids[0]); - } - first = false; - if ok.num_remaining == 0 { - assert_eq!( - swarm.behaviour_mut().queries.size(), 0, - "Expect no remaining queries when `num_remaining` is zero.", - ); - let mut known = HashSet::new(); - for b in swarm.behaviour_mut().kbuckets.iter() { - for e in b.iter() { - known.insert(e.node.key.preimage().clone()); - } + block_on(poll_fn(move |ctx| { + for (i, swarm) in swarms.iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::Bootstrap(Ok(ok)), + .. + }, + ))) => { + assert_eq!(id, qid); + assert_eq!(i, 0); + if first { + // Bootstrapping must start with a self-lookup. + assert_eq!(ok.peer, swarm_ids[0]); + } + first = false; + if ok.num_remaining == 0 { + assert_eq!( + swarm.behaviour_mut().queries.size(), + 0, + "Expect no remaining queries when `num_remaining` is zero.", + ); + let mut known = HashSet::new(); + for b in swarm.behaviour_mut().kbuckets.iter() { + for e in b.iter() { + known.insert(e.node.key.preimage().clone()); } - assert_eq!(expected_known, known); - return Poll::Ready(()) } + assert_eq!(expected_known, known); + return Poll::Ready(()); } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } - Poll::Pending - }) - ) + } + Poll::Pending + })) } QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _) @@ -228,7 +236,8 @@ fn bootstrap() { #[test] fn query_iter() { fn distances(key: &kbucket::Key, peers: Vec) -> Vec { - peers.into_iter() + peers + .into_iter() .map(kbucket::Key::from) .map(|k| k.distance(key)) .collect() @@ -236,7 +245,8 @@ fn query_iter() { fn run(rng: &mut impl Rng) { let num_total = rng.gen_range(2, 20); - let mut swarms = build_connected_nodes(num_total, 1).into_iter() + let mut swarms = build_connected_nodes(num_total, 1) + .into_iter() .map(|(_a, s)| s) .collect::>(); let swarm_ids: Vec<_> = swarms.iter().map(Swarm::local_peer_id).cloned().collect(); @@ -251,10 +261,10 @@ fn query_iter() { Some(q) => match q.info() { QueryInfo::GetClosestPeers { key } => { assert_eq!(&key[..], search_target.to_bytes().as_slice()) - }, - i => panic!("Unexpected query info: {:?}", i) - } - None => panic!("Query not found: {:?}", qid) + } + i => panic!("Unexpected query info: {:?}", i), + }, + None => panic!("Query not found: {:?}", qid), } // Set up expectations. @@ -264,37 +274,39 @@ fn query_iter() { expected_distances.sort(); // Run test - block_on( - poll_fn(move |ctx| { - for (i, swarm) in swarms.iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::GetClosestPeers(Ok(ok)), .. - }))) => { - assert_eq!(id, qid); - assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); - assert_eq!(swarm_ids[i], expected_swarm_id); - assert_eq!(swarm.behaviour_mut().queries.size(), 0); - assert!(expected_peer_ids.iter().all(|p| ok.peers.contains(p))); - let key = kbucket::Key::new(ok.key); - assert_eq!(expected_distances, distances(&key, ok.peers)); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + block_on(poll_fn(move |ctx| { + for (i, swarm) in swarms.iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::GetClosestPeers(Ok(ok)), + .. + }, + ))) => { + assert_eq!(id, qid); + assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); + assert_eq!(swarm_ids[i], expected_swarm_id); + assert_eq!(swarm.behaviour_mut().queries.size(), 0); + assert!(expected_peer_ids.iter().all(|p| ok.peers.contains(p))); + let key = kbucket::Key::new(ok.key); + assert_eq!(expected_distances, distances(&key, ok.peers)); + return Poll::Ready(()); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } - Poll::Pending - }) - ) + } + Poll::Pending + })) } let mut rng = thread_rng(); - for _ in 0 .. 10 { + for _ in 0..10 { run(&mut rng) } } @@ -304,42 +316,46 @@ fn unresponsive_not_returned_direct() { // Build one node. It contains fake addresses to non-existing nodes. We ask it to find a // random peer. We make sure that no fake address is returned. - let mut swarms = build_nodes(1).into_iter() + let mut swarms = build_nodes(1) + .into_iter() .map(|(_a, s)| s) .collect::>(); // Add fake addresses. - for _ in 0 .. 10 { - swarms[0].behaviour_mut().add_address(&PeerId::random(), Protocol::Udp(10u16).into()); + for _ in 0..10 { + swarms[0] + .behaviour_mut() + .add_address(&PeerId::random(), Protocol::Udp(10u16).into()); } // Ask first to search a random value. let search_target = PeerId::random(); swarms[0].behaviour_mut().get_closest_peers(search_target); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - result: QueryResult::GetClosestPeers(Ok(ok)), .. - }))) => { - assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); - assert_eq!(ok.peers.len(), 0); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + result: QueryResult::GetClosestPeers(Ok(ok)), + .. + }, + ))) => { + assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); + assert_eq!(ok.peers.len(), 0); + return Poll::Ready(()); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - Poll::Pending - }) - ) + Poll::Pending + })) } #[test] @@ -351,96 +367,120 @@ fn unresponsive_not_returned_indirect() { let mut swarms = build_nodes(2); // Add fake addresses to first. - for _ in 0 .. 10 { - swarms[0].1.behaviour_mut().add_address(&PeerId::random(), multiaddr![Udp(10u16)]); + for _ in 0..10 { + swarms[0] + .1 + .behaviour_mut() + .add_address(&PeerId::random(), multiaddr![Udp(10u16)]); } // Connect second to first. let first_peer_id = *swarms[0].1.local_peer_id(); let first_address = swarms[0].0.clone(); - swarms[1].1.behaviour_mut().add_address(&first_peer_id, first_address); + swarms[1] + .1 + .behaviour_mut() + .add_address(&first_peer_id, first_address); // Drop the swarm addresses. - let mut swarms = swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>(); + let mut swarms = swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>(); // Ask second to search a random value. let search_target = PeerId::random(); swarms[1].behaviour_mut().get_closest_peers(search_target); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - result: QueryResult::GetClosestPeers(Ok(ok)), .. - }))) => { - assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); - assert_eq!(ok.peers.len(), 1); - assert_eq!(ok.peers[0], first_peer_id); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + result: QueryResult::GetClosestPeers(Ok(ok)), + .. + }, + ))) => { + assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); + assert_eq!(ok.peers.len(), 1); + assert_eq!(ok.peers[0], first_peer_id); + return Poll::Ready(()); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - Poll::Pending - }) - ) + Poll::Pending + })) } #[test] fn get_record_not_found() { let mut swarms = build_nodes(3); - let swarm_ids: Vec<_> = swarms.iter() + let swarm_ids: Vec<_> = swarms + .iter() .map(|(_addr, swarm)| *swarm.local_peer_id()) .collect(); let (second, third) = (swarms[1].0.clone(), swarms[2].0.clone()); - swarms[0].1.behaviour_mut().add_address(&swarm_ids[1], second); - swarms[1].1.behaviour_mut().add_address(&swarm_ids[2], third); + swarms[0] + .1 + .behaviour_mut() + .add_address(&swarm_ids[1], second); + swarms[1] + .1 + .behaviour_mut() + .add_address(&swarm_ids[2], third); // Drop the swarm addresses. - let mut swarms = swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>(); + let mut swarms = swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>(); let target_key = record::Key::from(random_multihash()); - let qid = swarms[0].behaviour_mut().get_record(&target_key, Quorum::One); + let qid = swarms[0] + .behaviour_mut() + .get_record(&target_key, Quorum::One); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::GetRecord(Err(e)), .. - }))) => { - assert_eq!(id, qid); - if let GetRecordError::NotFound { key, closest_peers, } = e { - assert_eq!(key, target_key); - assert_eq!(closest_peers.len(), 2); - assert!(closest_peers.contains(&swarm_ids[1])); - assert!(closest_peers.contains(&swarm_ids[2])); - return Poll::Ready(()); - } else { - panic!("Unexpected error result: {:?}", e); - } + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::GetRecord(Err(e)), + .. + }, + ))) => { + assert_eq!(id, qid); + if let GetRecordError::NotFound { key, closest_peers } = e { + assert_eq!(key, target_key); + assert_eq!(closest_peers.len(), 2); + assert!(closest_peers.contains(&swarm_ids[1])); + assert!(closest_peers.contains(&swarm_ids[2])); + return Poll::Ready(()); + } else { + panic!("Unexpected error result: {:?}", e); } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - Poll::Pending - }) - ) + Poll::Pending + })) } /// A node joining a fully connected network via three (ALPHA_VALUE) bootnodes @@ -450,7 +490,8 @@ fn get_record_not_found() { fn put_record() { fn prop(records: Vec, seed: Seed) { let mut rng = StdRng::from_seed(seed.0); - let replication_factor = NonZeroUsize::new(rng.gen_range(1, (K_VALUE.get() / 2) + 1)).unwrap(); + let replication_factor = + NonZeroUsize::new(rng.gen_range(1, (K_VALUE.get() / 2) + 1)).unwrap(); // At least 4 nodes, 1 under test + 3 bootnodes. let num_total = usize::max(4, replication_factor.get() * 2); @@ -461,10 +502,8 @@ fn put_record() { } let mut swarms = { - let mut fully_connected_swarms = build_fully_connected_nodes_with_config( - num_total - 1, - config.clone(), - ); + let mut fully_connected_swarms = + build_fully_connected_nodes_with_config(num_total - 1, config.clone()); let mut single_swarm = build_node_with_config(config); // Connect `single_swarm` to three bootnodes. @@ -479,10 +518,14 @@ fn put_record() { swarms.append(&mut fully_connected_swarms); // Drop the swarm addresses. - swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>() + swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>() }; - let records = records.into_iter() + let records = records + .into_iter() .take(num_total) .map(|mut r| { // We don't want records to expire prematurely, as they would @@ -491,12 +534,15 @@ fn put_record() { r.expires = r.expires.map(|t| t + Duration::from_secs(60)); (r.key.clone(), r) }) - .collect::>(); + .collect::>(); // Initiate put_record queries. let mut qids = HashSet::new(); for r in records.values() { - let qid = swarms[0].behaviour_mut().put_record(r.clone(), Quorum::All).unwrap(); + let qid = swarms[0] + .behaviour_mut() + .put_record(r.clone(), Quorum::All) + .unwrap(); match swarms[0].behaviour_mut().query(&qid) { Some(q) => match q.info() { QueryInfo::PutRecord { phase, record, .. } => { @@ -505,10 +551,10 @@ fn put_record() { assert_eq!(record.value, r.value); assert!(record.expires.is_some()); qids.insert(qid); - }, - i => panic!("Unexpected query info: {:?}", i) - } - None => panic!("Query not found: {:?}", qid) + } + i => panic!("Unexpected query info: {:?}", i), + }, + None => panic!("Query not found: {:?}", qid), } } @@ -517,118 +563,136 @@ fn put_record() { // The accumulated results for one round of publishing. let mut results = Vec::new(); - block_on( - poll_fn(move |ctx| loop { - // Poll all swarms until they are "Pending". - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::PutRecord(res), stats - }))) | - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::RepublishRecord(res), stats - }))) => { - assert!(qids.is_empty() || qids.remove(&id)); - assert!(stats.duration().is_some()); - assert!(stats.num_successes() >= replication_factor.get() as u32); - assert!(stats.num_requests() >= stats.num_successes()); - assert_eq!(stats.num_failures(), 0); - match res { - Err(e) => panic!("{:?}", e), - Ok(ok) => { - assert!(records.contains_key(&ok.key)); - let record = swarm.behaviour_mut().store.get(&ok.key).unwrap(); - results.push(record.into_owned()); - } + block_on(poll_fn(move |ctx| loop { + // Poll all swarms until they are "Pending". + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::PutRecord(res), + stats, + }, + ))) + | Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::RepublishRecord(res), + stats, + }, + ))) => { + assert!(qids.is_empty() || qids.remove(&id)); + assert!(stats.duration().is_some()); + assert!(stats.num_successes() >= replication_factor.get() as u32); + assert!(stats.num_requests() >= stats.num_successes()); + assert_eq!(stats.num_failures(), 0); + match res { + Err(e) => panic!("{:?}", e), + Ok(ok) => { + assert!(records.contains_key(&ok.key)); + let record = swarm.behaviour_mut().store.get(&ok.key).unwrap(); + results.push(record.into_owned()); } } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - // All swarms are Pending and not enough results have been collected - // so far, thus wait to be polled again for further progress. - if results.len() != records.len() { - return Poll::Pending - } + // All swarms are Pending and not enough results have been collected + // so far, thus wait to be polled again for further progress. + if results.len() != records.len() { + return Poll::Pending; + } - // Consume the results, checking that each record was replicated - // correctly to the closest peers to the key. - while let Some(r) = results.pop() { - let expected = records.get(&r.key).unwrap(); - - assert_eq!(r.key, expected.key); - assert_eq!(r.value, expected.value); - assert_eq!(r.expires, expected.expires); - assert_eq!(r.publisher, Some(*swarms[0].local_peer_id())); - - let key = kbucket::Key::new(r.key.clone()); - let mut expected = swarms.iter() - .skip(1) - .map(Swarm::local_peer_id) - .cloned() - .collect::>(); - expected.sort_by(|id1, id2| - kbucket::Key::from(*id1).distance(&key).cmp( - &kbucket::Key::from(*id2).distance(&key))); - - let expected = expected - .into_iter() - .take(replication_factor.get()) - .collect::>(); - - let actual = swarms.iter() - .skip(1) - .filter_map(|swarm| - if swarm.behaviour().store.get(key.preimage()).is_some() { - Some(*swarm.local_peer_id()) - } else { - None - }) - .collect::>(); - - assert_eq!(actual.len(), replication_factor.get()); - - let actual_not_expected = actual.difference(&expected) - .collect::>(); - assert!( - actual_not_expected.is_empty(), - "Did not expect records to be stored on nodes {:?}.", - actual_not_expected, - ); - - let expected_not_actual = expected.difference(&actual) - .collect::>(); - assert!(expected_not_actual.is_empty(), - "Expected record to be stored on nodes {:?}.", - expected_not_actual, - ); - } + // Consume the results, checking that each record was replicated + // correctly to the closest peers to the key. + while let Some(r) = results.pop() { + let expected = records.get(&r.key).unwrap(); + + assert_eq!(r.key, expected.key); + assert_eq!(r.value, expected.value); + assert_eq!(r.expires, expected.expires); + assert_eq!(r.publisher, Some(*swarms[0].local_peer_id())); + + let key = kbucket::Key::new(r.key.clone()); + let mut expected = swarms + .iter() + .skip(1) + .map(Swarm::local_peer_id) + .cloned() + .collect::>(); + expected.sort_by(|id1, id2| { + kbucket::Key::from(*id1) + .distance(&key) + .cmp(&kbucket::Key::from(*id2).distance(&key)) + }); + + let expected = expected + .into_iter() + .take(replication_factor.get()) + .collect::>(); + + let actual = swarms + .iter() + .skip(1) + .filter_map(|swarm| { + if swarm.behaviour().store.get(key.preimage()).is_some() { + Some(*swarm.local_peer_id()) + } else { + None + } + }) + .collect::>(); - if republished { - assert_eq!(swarms[0].behaviour_mut().store.records().count(), records.len()); - assert_eq!(swarms[0].behaviour_mut().queries.size(), 0); - for k in records.keys() { - swarms[0].behaviour_mut().store.remove(&k); - } - assert_eq!(swarms[0].behaviour_mut().store.records().count(), 0); - // All records have been republished, thus the test is complete. - return Poll::Ready(()); + assert_eq!(actual.len(), replication_factor.get()); + + let actual_not_expected = actual.difference(&expected).collect::>(); + assert!( + actual_not_expected.is_empty(), + "Did not expect records to be stored on nodes {:?}.", + actual_not_expected, + ); + + let expected_not_actual = expected.difference(&actual).collect::>(); + assert!( + expected_not_actual.is_empty(), + "Expected record to be stored on nodes {:?}.", + expected_not_actual, + ); + } + + if republished { + assert_eq!( + swarms[0].behaviour_mut().store.records().count(), + records.len() + ); + assert_eq!(swarms[0].behaviour_mut().queries.size(), 0); + for k in records.keys() { + swarms[0].behaviour_mut().store.remove(&k); } + assert_eq!(swarms[0].behaviour_mut().store.records().count(), 0); + // All records have been republished, thus the test is complete. + return Poll::Ready(()); + } - // Tell the replication job to republish asap. - swarms[0].behaviour_mut().put_record_job.as_mut().unwrap().asap(true); - republished = true; - }) - ) + // Tell the replication job to republish asap. + swarms[0] + .behaviour_mut() + .put_record_job + .as_mut() + .unwrap() + .asap(true); + republished = true; + })) } - QuickCheck::new().tests(3).quickcheck(prop as fn(_,_) -> _) + QuickCheck::new().tests(3).quickcheck(prop as fn(_, _) -> _) } #[test] @@ -637,95 +701,109 @@ fn get_record() { // Let first peer know of second peer and second peer know of third peer. for i in 0..2 { - let (peer_id, address) = (Swarm::local_peer_id(&swarms[i+1].1).clone(), swarms[i+1].0.clone()); + let (peer_id, address) = ( + Swarm::local_peer_id(&swarms[i + 1].1).clone(), + swarms[i + 1].0.clone(), + ); swarms[i].1.behaviour_mut().add_address(&peer_id, address); } // Drop the swarm addresses. - let mut swarms = swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>(); + let mut swarms = swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>(); - let record = Record::new(random_multihash(), vec![4,5,6]); + let record = Record::new(random_multihash(), vec![4, 5, 6]); let expected_cache_candidate = *Swarm::local_peer_id(&swarms[1]); swarms[2].behaviour_mut().store.put(record.clone()).unwrap(); - let qid = swarms[0].behaviour_mut().get_record(&record.key, Quorum::One); + let qid = swarms[0] + .behaviour_mut() + .get_record(&record.key, Quorum::One); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { id, - result: QueryResult::GetRecord(Ok(GetRecordOk { - records, cache_candidates - })), + result: + QueryResult::GetRecord(Ok(GetRecordOk { + records, + cache_candidates, + })), .. - }))) => { - assert_eq!(id, qid); - assert_eq!(records.len(), 1); - assert_eq!(records.first().unwrap().record, record); - assert_eq!(cache_candidates.len(), 1); - assert_eq!(cache_candidates.values().next(), Some(&expected_cache_candidate)); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + }, + ))) => { + assert_eq!(id, qid); + assert_eq!(records.len(), 1); + assert_eq!(records.first().unwrap().record, record); + assert_eq!(cache_candidates.len(), 1); + assert_eq!( + cache_candidates.values().next(), + Some(&expected_cache_candidate) + ); + return Poll::Ready(()); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - Poll::Pending - }) - ) + Poll::Pending + })) } #[test] fn get_record_many() { // TODO: Randomise let num_nodes = 12; - let mut swarms = build_connected_nodes(num_nodes, 3).into_iter() + let mut swarms = build_connected_nodes(num_nodes, 3) + .into_iter() .map(|(_addr, swarm)| swarm) .collect::>(); let num_results = 10; - let record = Record::new(random_multihash(), vec![4,5,6]); + let record = Record::new(random_multihash(), vec![4, 5, 6]); - for i in 0 .. num_nodes { + for i in 0..num_nodes { swarms[i].behaviour_mut().store.put(record.clone()).unwrap(); } let quorum = Quorum::N(NonZeroUsize::new(num_results).unwrap()); let qid = swarms[0].behaviour_mut().get_record(&record.key, quorum); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { id, result: QueryResult::GetRecord(Ok(GetRecordOk { records, .. })), .. - }))) => { - assert_eq!(id, qid); - assert!(records.len() >= num_results); - assert!(records.into_iter().all(|r| r.record == record)); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + }, + ))) => { + assert_eq!(id, qid); + assert!(records.len() >= num_results); + assert!(records.into_iter().all(|r| r.record == record)); + return Poll::Ready(()); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } - Poll::Pending - }) - ) + } + Poll::Pending + })) } /// A node joining a fully connected network via three (ALPHA_VALUE) bootnodes @@ -735,7 +813,8 @@ fn get_record_many() { fn add_provider() { fn prop(keys: Vec, seed: Seed) { let mut rng = StdRng::from_seed(seed.0); - let replication_factor = NonZeroUsize::new(rng.gen_range(1, (K_VALUE.get() / 2) + 1)).unwrap(); + let replication_factor = + NonZeroUsize::new(rng.gen_range(1, (K_VALUE.get() / 2) + 1)).unwrap(); // At least 4 nodes, 1 under test + 3 bootnodes. let num_total = usize::max(4, replication_factor.get() * 2); @@ -746,10 +825,8 @@ fn add_provider() { } let mut swarms = { - let mut fully_connected_swarms = build_fully_connected_nodes_with_config( - num_total - 1, - config.clone(), - ); + let mut fully_connected_swarms = + build_fully_connected_nodes_with_config(num_total - 1, config.clone()); let mut single_swarm = build_node_with_config(config); // Connect `single_swarm` to three bootnodes. @@ -764,7 +841,10 @@ fn add_provider() { swarms.append(&mut fully_connected_swarms); // Drop addresses before returning. - swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>() + swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>() }; let keys: HashSet<_> = keys.into_iter().take(num_total).collect(); @@ -778,113 +858,136 @@ fn add_provider() { // Initiate the first round of publishing. let mut qids = HashSet::new(); for k in &keys { - let qid = swarms[0].behaviour_mut().start_providing(k.clone()).unwrap(); + let qid = swarms[0] + .behaviour_mut() + .start_providing(k.clone()) + .unwrap(); qids.insert(qid); } - block_on( - poll_fn(move |ctx| loop { - // Poll all swarms until they are "Pending". - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::StartProviding(res), .. - }))) | - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::RepublishProvider(res), .. - }))) => { - assert!(qids.is_empty() || qids.remove(&id)); - match res { - Err(e) => panic!("{:?}", e), - Ok(ok) => { - assert!(keys.contains(&ok.key)); - results.push(ok.key); - } + block_on(poll_fn(move |ctx| loop { + // Poll all swarms until they are "Pending". + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::StartProviding(res), + .. + }, + ))) + | Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::RepublishProvider(res), + .. + }, + ))) => { + assert!(qids.is_empty() || qids.remove(&id)); + match res { + Err(e) => panic!("{:?}", e), + Ok(ok) => { + assert!(keys.contains(&ok.key)); + results.push(ok.key); } } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - if results.len() == keys.len() { - // All requests have been sent for one round of publishing. - published = true - } + if results.len() == keys.len() { + // All requests have been sent for one round of publishing. + published = true + } - if !published { - // Still waiting for all requests to be sent for one round - // of publishing. - return Poll::Pending - } + if !published { + // Still waiting for all requests to be sent for one round + // of publishing. + return Poll::Pending; + } - // A round of publishing is complete. Consume the results, checking that - // each key was published to the `replication_factor` closest peers. - while let Some(key) = results.pop() { - // Collect the nodes that have a provider record for `key`. - let actual = swarms.iter().skip(1) - .filter_map(|swarm| - if swarm.behaviour().store.providers(&key).len() == 1 { - Some(Swarm::local_peer_id(&swarm).clone()) - } else { - None - }) - .collect::>(); - - if actual.len() != replication_factor.get() { - // Still waiting for some nodes to process the request. - results.push(key); - return Poll::Pending - } + // A round of publishing is complete. Consume the results, checking that + // each key was published to the `replication_factor` closest peers. + while let Some(key) = results.pop() { + // Collect the nodes that have a provider record for `key`. + let actual = swarms + .iter() + .skip(1) + .filter_map(|swarm| { + if swarm.behaviour().store.providers(&key).len() == 1 { + Some(Swarm::local_peer_id(&swarm).clone()) + } else { + None + } + }) + .collect::>(); - let mut expected = swarms.iter() - .skip(1) - .map(Swarm::local_peer_id) - .cloned() - .collect::>(); - let kbucket_key = kbucket::Key::new(key); - expected.sort_by(|id1, id2| - kbucket::Key::from(*id1).distance(&kbucket_key).cmp( - &kbucket::Key::from(*id2).distance(&kbucket_key))); - - let expected = expected - .into_iter() - .take(replication_factor.get()) - .collect::>(); - - assert_eq!(actual, expected); + if actual.len() != replication_factor.get() { + // Still waiting for some nodes to process the request. + results.push(key); + return Poll::Pending; } - // One round of publishing is complete. - assert!(results.is_empty()); - for swarm in &swarms { - assert_eq!(swarm.behaviour().queries.size(), 0); - } + let mut expected = swarms + .iter() + .skip(1) + .map(Swarm::local_peer_id) + .cloned() + .collect::>(); + let kbucket_key = kbucket::Key::new(key); + expected.sort_by(|id1, id2| { + kbucket::Key::from(*id1) + .distance(&kbucket_key) + .cmp(&kbucket::Key::from(*id2).distance(&kbucket_key)) + }); + + let expected = expected + .into_iter() + .take(replication_factor.get()) + .collect::>(); + + assert_eq!(actual, expected); + } - if republished { - assert_eq!(swarms[0].behaviour_mut().store.provided().count(), keys.len()); - for k in &keys { - swarms[0].behaviour_mut().stop_providing(&k); - } - assert_eq!(swarms[0].behaviour_mut().store.provided().count(), 0); - // All records have been republished, thus the test is complete. - return Poll::Ready(()); + // One round of publishing is complete. + assert!(results.is_empty()); + for swarm in &swarms { + assert_eq!(swarm.behaviour().queries.size(), 0); + } + + if republished { + assert_eq!( + swarms[0].behaviour_mut().store.provided().count(), + keys.len() + ); + for k in &keys { + swarms[0].behaviour_mut().stop_providing(&k); } + assert_eq!(swarms[0].behaviour_mut().store.provided().count(), 0); + // All records have been republished, thus the test is complete. + return Poll::Ready(()); + } - // Initiate the second round of publishing by telling the - // periodic provider job to run asap. - swarms[0].behaviour_mut().add_provider_job.as_mut().unwrap().asap(); - published = false; - republished = true; - }) - ) + // Initiate the second round of publishing by telling the + // periodic provider job to run asap. + swarms[0] + .behaviour_mut() + .add_provider_job + .as_mut() + .unwrap() + .asap(); + published = false; + republished = true; + })) } - QuickCheck::new().tests(3).quickcheck(prop as fn(_,_)) + QuickCheck::new().tests(3).quickcheck(prop as fn(_, _)) } /// User code should be able to start queries beyond the internal @@ -894,33 +997,32 @@ fn add_provider() { fn exceed_jobs_max_queries() { let (_addr, mut swarm) = build_node(); let num = JOBS_MAX_QUERIES + 1; - for _ in 0 .. num { + for _ in 0..num { swarm.behaviour_mut().get_closest_peers(PeerId::random()); } assert_eq!(swarm.behaviour_mut().queries.size(), num); - block_on( - poll_fn(move |ctx| { - for _ in 0 .. num { - // There are no other nodes, so the queries finish instantly. - loop { - if let Poll::Ready(Some(e)) = swarm.poll_next_unpin(ctx) { - match e { - SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - result: QueryResult::GetClosestPeers(Ok(r)), .. - }) => break assert!(r.peers.is_empty()), - SwarmEvent::Behaviour(e) => panic!("Unexpected event: {:?}", e), - _ => {} - } - } else { - panic!("Expected event") + block_on(poll_fn(move |ctx| { + for _ in 0..num { + // There are no other nodes, so the queries finish instantly. + loop { + if let Poll::Ready(Some(e)) = swarm.poll_next_unpin(ctx) { + match e { + SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { + result: QueryResult::GetClosestPeers(Ok(r)), + .. + }) => break assert!(r.peers.is_empty()), + SwarmEvent::Behaviour(e) => panic!("Unexpected event: {:?}", e), + _ => {} } + } else { + panic!("Expected event") } } - Poll::Ready(()) - }) - ) + } + Poll::Ready(()) + })) } #[test] @@ -953,11 +1055,22 @@ fn disjoint_query_does_not_finish_before_all_paths_did() { // Make `bob` and `trudy` aware of their version of the record searched by // `alice`. bob.1.behaviour_mut().store.put(record_bob.clone()).unwrap(); - trudy.1.behaviour_mut().store.put(record_trudy.clone()).unwrap(); + trudy + .1 + .behaviour_mut() + .store + .put(record_trudy.clone()) + .unwrap(); // Make `trudy` and `bob` known to `alice`. - alice.1.behaviour_mut().add_address(&trudy.1.local_peer_id(), trudy.0.clone()); - alice.1.behaviour_mut().add_address(&bob.1.local_peer_id(), bob.0.clone()); + alice + .1 + .behaviour_mut() + .add_address(&trudy.1.local_peer_id(), trudy.0.clone()); + alice + .1 + .behaviour_mut() + .add_address(&bob.1.local_peer_id(), bob.0.clone()); // Drop the swarm addresses. let (mut alice, mut bob, mut trudy) = (alice.1, bob.1, trudy.1); @@ -971,45 +1084,48 @@ fn disjoint_query_does_not_finish_before_all_paths_did() { // Poll only `alice` and `trudy` expecting `alice` not yet to return a query // result as it is not able to connect to `bob` just yet. - block_on( - poll_fn(|ctx| { - for (i, swarm) in [&mut alice, &mut trudy].iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted{ + block_on(poll_fn(|ctx| { + for (i, swarm) in [&mut alice, &mut trudy].iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { result: QueryResult::GetRecord(result), - .. - }))) => { - if i != 0 { - panic!("Expected `QueryResult` from Alice.") - } + .. + }, + ))) => { + if i != 0 { + panic!("Expected `QueryResult` from Alice.") + } - match result { - Ok(_) => panic!( - "Expected query not to finish until all \ + match result { + Ok(_) => panic!( + "Expected query not to finish until all \ disjoint paths have been explored.", - ), - Err(e) => panic!("{:?}", e), - } + ), + Err(e) => panic!("{:?}", e), } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - Poll::Ready(None) => panic!("Expected Kademlia behaviour not to finish."), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + Poll::Ready(None) => panic!("Expected Kademlia behaviour not to finish."), + Poll::Pending => break, } } + } - // Make sure not to wait until connections to `bob` time out. - before_timeout.poll_unpin(ctx) - }) - ); + // Make sure not to wait until connections to `bob` time out. + before_timeout.poll_unpin(ctx) + })); // Make sure `alice` has exactly one query with `trudy`'s record only. assert_eq!(1, alice.behaviour().queries.iter().count()); - alice.behaviour().queries.iter().for_each(|q| { - match &q.inner.info { - QueryInfo::GetRecord{ records, .. } => { + alice + .behaviour() + .queries + .iter() + .for_each(|q| match &q.inner.info { + QueryInfo::GetRecord { records, .. } => { assert_eq!( *records, vec![PeerRecord { @@ -1017,44 +1133,41 @@ fn disjoint_query_does_not_finish_before_all_paths_did() { record: record_trudy.clone(), }], ); - }, + } i @ _ => panic!("Unexpected query info: {:?}", i), - } - }); + }); // Poll `alice` and `bob` expecting `alice` to return a successful query // result as it is now able to explore the second disjoint path. - let records = block_on( - poll_fn(|ctx| { - for (i, swarm) in [&mut alice, &mut bob].iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted{ + let records = block_on(poll_fn(|ctx| { + for (i, swarm) in [&mut alice, &mut bob].iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { result: QueryResult::GetRecord(result), .. - }))) => { - if i != 0 { - panic!("Expected `QueryResult` from Alice.") - } + }, + ))) => { + if i != 0 { + panic!("Expected `QueryResult` from Alice.") + } - match result { - Ok(ok) => return Poll::Ready(ok.records), - Err(e) => unreachable!("{:?}", e), - } + match result { + Ok(ok) => return Poll::Ready(ok.records), + Err(e) => unreachable!("{:?}", e), } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - Poll::Ready(None) => panic!( - "Expected Kademlia behaviour not to finish.", - ), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + Poll::Ready(None) => panic!("Expected Kademlia behaviour not to finish.",), + Poll::Pending => break, } } + } - Poll::Pending - }) - ); + Poll::Pending + })); assert_eq!(2, records.len()); assert!(records.contains(&PeerRecord { @@ -1076,25 +1189,31 @@ fn manual_bucket_inserts() { // 1 -> 2 -> [3 -> ...] let mut swarms = build_connected_nodes_with_config(3, 1, cfg); // The peers and their addresses for which we expect `RoutablePeer` events. - let mut expected = swarms.iter().skip(2) + let mut expected = swarms + .iter() + .skip(2) .map(|(a, s)| { let pid = *Swarm::local_peer_id(s); let addr = a.clone().with(Protocol::P2p(pid.into())); (addr, pid) }) - .collect::>(); + .collect::>(); // We collect the peers for which a `RoutablePeer` event // was received in here to check at the end of the test // that none of them was inserted into a bucket. let mut routable = Vec::new(); // Start an iterative query from the first peer. - swarms[0].1.behaviour_mut().get_closest_peers(PeerId::random()); + swarms[0] + .1 + .behaviour_mut() + .get_closest_peers(PeerId::random()); block_on(poll_fn(move |ctx| { for (_, swarm) in swarms.iter_mut() { loop { match swarm.poll_next_unpin(ctx) { Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::RoutablePeer { - peer, address + peer, + address, }))) => { assert_eq!(peer, expected.remove(&address).expect("Missing address")); routable.push(peer); @@ -1103,11 +1222,11 @@ fn manual_bucket_inserts() { let bucket = swarm.behaviour_mut().kbucket(*peer).unwrap(); assert!(bucket.iter().all(|e| e.node.key.preimage() != peer)); } - return Poll::Ready(()) + return Poll::Ready(()); } } - Poll::Ready(..) => {}, - Poll::Pending => break + Poll::Ready(..) => {} + Poll::Pending => break, } } } @@ -1124,19 +1243,14 @@ fn network_behaviour_inject_address_change() { let old_address: Multiaddr = Protocol::Memory(1).into(); let new_address: Multiaddr = Protocol::Memory(2).into(); - let mut kademlia = Kademlia::new( - local_peer_id.clone(), - MemoryStore::new(local_peer_id), - ); + let mut kademlia = Kademlia::new(local_peer_id.clone(), MemoryStore::new(local_peer_id)); - let endpoint = ConnectedPoint::Dialer { address: old_address.clone() }; + let endpoint = ConnectedPoint::Dialer { + address: old_address.clone(), + }; // Mimick a connection being established. - kademlia.inject_connection_established( - &remote_peer_id, - &connection_id, - &endpoint, - ); + kademlia.inject_connection_established(&remote_peer_id, &connection_id, &endpoint); kademlia.inject_connected(&remote_peer_id); // At this point the remote is not yet known to support the @@ -1149,7 +1263,7 @@ fn network_behaviour_inject_address_change() { kademlia.inject_event( remote_peer_id.clone(), connection_id.clone(), - KademliaHandlerEvent::ProtocolConfirmed { endpoint } + KademliaHandlerEvent::ProtocolConfirmed { endpoint }, ); assert_eq!( @@ -1160,8 +1274,12 @@ fn network_behaviour_inject_address_change() { kademlia.inject_address_change( &remote_peer_id, &connection_id, - &ConnectedPoint::Dialer { address: old_address.clone() }, - &ConnectedPoint::Dialer { address: new_address.clone() }, + &ConnectedPoint::Dialer { + address: old_address.clone(), + }, + &ConnectedPoint::Dialer { + address: new_address.clone(), + }, ); assert_eq!( diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index 70b8fdd955a..3c955bb428a 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -24,23 +24,19 @@ use crate::protocol::{ }; use crate::record::{self, Record}; use futures::prelude::*; -use libp2p_swarm::{ - IntoProtocolsHandler, - KeepAlive, - NegotiatedSubstream, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr -}; use libp2p_core::{ - ConnectedPoint, - PeerId, either::EitherOutput, - upgrade::{self, InboundUpgrade, OutboundUpgrade} + upgrade::{self, InboundUpgrade, OutboundUpgrade}, + ConnectedPoint, PeerId, +}; +use libp2p_swarm::{ + IntoProtocolsHandler, KeepAlive, NegotiatedSubstream, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; use log::trace; -use std::{error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll, time::Duration}; +use std::{ + error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll, time::Duration, +}; use wasm_timer::Instant; /// A prototype from which [`KademliaHandler`]s can be constructed. @@ -51,7 +47,10 @@ pub struct KademliaHandlerProto { impl KademliaHandlerProto { pub fn new(config: KademliaHandlerConfig) -> Self { - KademliaHandlerProto { config, _type: PhantomData } + KademliaHandlerProto { + config, + _type: PhantomData, + } } } @@ -151,7 +150,11 @@ enum SubstreamState { /// Waiting for the user to send a `KademliaHandlerIn` event containing the response. InWaitingUser(UniqueConnecId, KadInStreamSink), /// Waiting to send an answer back to the remote. - InPendingSend(UniqueConnecId, KadInStreamSink, KadResponseMsg), + InPendingSend( + UniqueConnecId, + KadInStreamSink, + KadResponseMsg, + ), /// Waiting to flush an answer back to the remote. InPendingFlush(UniqueConnecId, KadInStreamSink), /// The substream is being closed. @@ -164,23 +167,28 @@ impl SubstreamState { /// If the substream is not ready to be closed, returns it back. fn try_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { match self { - SubstreamState::OutPendingOpen(_, _) - | SubstreamState::OutReportError(_, _) => Poll::Ready(()), + SubstreamState::OutPendingOpen(_, _) | SubstreamState::OutReportError(_, _) => { + Poll::Ready(()) + } SubstreamState::OutPendingSend(ref mut stream, _, _) | SubstreamState::OutPendingFlush(ref mut stream, _) | SubstreamState::OutWaitingAnswer(ref mut stream, _) - | SubstreamState::OutClosing(ref mut stream) => match Sink::poll_close(Pin::new(stream), cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - }, + | SubstreamState::OutClosing(ref mut stream) => { + match Sink::poll_close(Pin::new(stream), cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } SubstreamState::InWaitingMessage(_, ref mut stream) | SubstreamState::InWaitingUser(_, ref mut stream) | SubstreamState::InPendingSend(_, ref mut stream, _) | SubstreamState::InPendingFlush(_, ref mut stream) - | SubstreamState::InClosing(ref mut stream) => match Sink::poll_close(Pin::new(stream), cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - }, + | SubstreamState::InClosing(ref mut stream) => { + match Sink::poll_close(Pin::new(stream), cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } } } } @@ -282,7 +290,7 @@ pub enum KademliaHandlerEvent { value: Vec, /// The user data passed to the `PutValue`. user_data: TUserData, - } + }, } /// Error that can happen when requesting an RPC query. @@ -301,13 +309,16 @@ impl fmt::Display for KademliaHandlerQueryErr { match self { KademliaHandlerQueryErr::Upgrade(err) => { write!(f, "Error while performing Kademlia query: {}", err) - }, + } KademliaHandlerQueryErr::UnexpectedMessage => { - write!(f, "Remote answered our Kademlia RPC query with the wrong message type") - }, + write!( + f, + "Remote answered our Kademlia RPC query with the wrong message type" + ) + } KademliaHandlerQueryErr::Io(err) => { write!(f, "I/O error during a Kademlia RPC query: {}", err) - }, + } } } } @@ -424,7 +435,7 @@ pub enum KademliaHandlerIn { value: Vec, /// Identifier of the request that was made by the remote. request_id: KademliaRequestId, - } + }, } /// Unique identifier for a request. Must be passed back in order to answer a request from @@ -470,7 +481,8 @@ where fn listen_protocol(&self) -> SubstreamProtocol { if self.config.allow_listening { - SubstreamProtocol::new(self.config.protocol_config.clone(), ()).map_upgrade(upgrade::EitherUpgrade::A) + SubstreamProtocol::new(self.config.protocol_config.clone(), ()) + .map_upgrade(upgrade::EitherUpgrade::A) } else { SubstreamProtocol::new(upgrade::EitherUpgrade::B(upgrade::DeniedUpgrade), ()) } @@ -481,7 +493,8 @@ where protocol: >::Output, (msg, user_data): Self::OutboundOpenInfo, ) { - self.substreams.push(SubstreamState::OutPendingSend(protocol, msg, user_data)); + self.substreams + .push(SubstreamState::OutPendingSend(protocol, msg, user_data)); if let ProtocolStatus::Unconfirmed = self.protocol_status { // Upon the first successfully negotiated substream, we know that the // remote is configured with the same protocol name and we want @@ -493,7 +506,7 @@ where fn inject_fully_negotiated_inbound( &mut self, protocol: >::Output, - (): Self::InboundOpenInfo + (): Self::InboundOpenInfo, ) { // If `self.allow_listening` is false, then we produced a `DeniedUpgrade` and `protocol` // is a `Void`. @@ -505,7 +518,8 @@ where debug_assert!(self.config.allow_listening); let connec_unique_id = self.next_connec_unique_id; self.next_connec_unique_id.0 += 1; - self.substreams.push(SubstreamState::InWaitingMessage(connec_unique_id, protocol)); + self.substreams + .push(SubstreamState::InWaitingMessage(connec_unique_id, protocol)); if let ProtocolStatus::Unconfirmed = self.protocol_status { // Upon the first successfully negotiated substream, we know that the // remote is configured with the same protocol name and we want @@ -518,8 +532,9 @@ where match message { KademliaHandlerIn::Reset(request_id) => { let pos = self.substreams.iter().position(|state| match state { - SubstreamState::InWaitingUser(conn_id, _) => - conn_id == &request_id.connec_unique_id, + SubstreamState::InWaitingUser(conn_id, _) => { + conn_id == &request_id.connec_unique_id + } _ => false, }); if let Some(pos) = pos { @@ -531,15 +546,17 @@ where } KademliaHandlerIn::FindNodeReq { key, user_data } => { let msg = KadRequestMsg::FindNode { key }; - self.substreams.push(SubstreamState::OutPendingOpen(msg, Some(user_data))); + self.substreams + .push(SubstreamState::OutPendingOpen(msg, Some(user_data))); } KademliaHandlerIn::FindNodeRes { closer_peers, request_id, } => { let pos = self.substreams.iter().position(|state| match state { - SubstreamState::InWaitingUser(ref conn_id, _) => - conn_id == &request_id.connec_unique_id, + SubstreamState::InWaitingUser(ref conn_id, _) => { + conn_id == &request_id.connec_unique_id + } _ => false, }); @@ -549,9 +566,7 @@ where _ => unreachable!(), }; - let msg = KadResponseMsg::FindNode { - closer_peers, - }; + let msg = KadResponseMsg::FindNode { closer_peers }; self.substreams .push(SubstreamState::InPendingSend(conn_id, substream, msg)); } @@ -591,12 +606,13 @@ where } KademliaHandlerIn::AddProvider { key, provider } => { let msg = KadRequestMsg::AddProvider { key, provider }; - self.substreams.push(SubstreamState::OutPendingOpen(msg, None)); + self.substreams + .push(SubstreamState::OutPendingOpen(msg, None)); } KademliaHandlerIn::GetRecord { key, user_data } => { let msg = KadRequestMsg::GetValue { key }; - self.substreams.push(SubstreamState::OutPendingOpen(msg, Some(user_data))); - + self.substreams + .push(SubstreamState::OutPendingOpen(msg, Some(user_data))); } KademliaHandlerIn::PutRecord { record, user_data } => { let msg = KadRequestMsg::PutValue { record }; @@ -609,8 +625,9 @@ where request_id, } => { let pos = self.substreams.iter().position(|state| match state { - SubstreamState::InWaitingUser(ref conn_id, _) - => conn_id == &request_id.connec_unique_id, + SubstreamState::InWaitingUser(ref conn_id, _) => { + conn_id == &request_id.connec_unique_id + } _ => false, }); @@ -636,9 +653,9 @@ where let pos = self.substreams.iter().position(|state| match state { SubstreamState::InWaitingUser(ref conn_id, _) if conn_id == &request_id.connec_unique_id => - { - true - } + { + true + } _ => false, }); @@ -648,10 +665,7 @@ where _ => unreachable!(), }; - let msg = KadResponseMsg::PutValue { - key, - value, - }; + let msg = KadResponseMsg::PutValue { key, value }; self.substreams .push(SubstreamState::InPendingSend(conn_id, substream, msg)); } @@ -680,7 +694,12 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { if self.substreams.is_empty() { return Poll::Pending; @@ -690,8 +709,9 @@ where self.protocol_status = ProtocolStatus::Reported; return Poll::Ready(ProtocolsHandlerEvent::Custom( KademliaHandlerEvent::ProtocolConfirmed { - endpoint: self.endpoint.clone() - })) + endpoint: self.endpoint.clone(), + }, + )); } // We remove each element from `substreams` one by one and add them back. @@ -706,7 +726,8 @@ where } (None, Some(event), _) => { if self.substreams.is_empty() { - self.keep_alive = KeepAlive::Until(Instant::now() + self.config.idle_timeout); + self.keep_alive = + KeepAlive::Until(Instant::now() + self.config.idle_timeout); } return Poll::Ready(event); } @@ -765,36 +786,35 @@ fn advance_substream( >, >, bool, -) -{ +) { match state { SubstreamState::OutPendingOpen(msg, user_data) => { let ev = ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(upgrade, (msg, user_data)) + protocol: SubstreamProtocol::new(upgrade, (msg, user_data)), }; (None, Some(ev), false) } SubstreamState::OutPendingSend(mut substream, msg, user_data) => { match Sink::poll_ready(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => { - match Sink::start_send(Pin::new(&mut substream), msg) { - Ok(()) => ( - Some(SubstreamState::OutPendingFlush(substream, user_data)), - None, - true, - ), - Err(error) => { - let event = if let Some(user_data) = user_data { - Some(ProtocolsHandlerEvent::Custom(KademliaHandlerEvent::QueryError { + Poll::Ready(Ok(())) => match Sink::start_send(Pin::new(&mut substream), msg) { + Ok(()) => ( + Some(SubstreamState::OutPendingFlush(substream, user_data)), + None, + true, + ), + Err(error) => { + let event = if let Some(user_data) = user_data { + Some(ProtocolsHandlerEvent::Custom( + KademliaHandlerEvent::QueryError { error: KademliaHandlerQueryErr::Io(error), - user_data - })) - } else { - None - }; - - (None, event, false) - } + user_data, + }, + )) + } else { + None + }; + + (None, event, false) } }, Poll::Pending => ( @@ -804,10 +824,12 @@ fn advance_substream( ), Poll::Ready(Err(error)) => { let event = if let Some(user_data) = user_data { - Some(ProtocolsHandlerEvent::Custom(KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(error), - user_data - })) + Some(ProtocolsHandlerEvent::Custom( + KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(error), + user_data, + }, + )) } else { None }; @@ -836,10 +858,12 @@ fn advance_substream( ), Poll::Ready(Err(error)) => { let event = if let Some(user_data) = user_data { - Some(ProtocolsHandlerEvent::Custom(KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(error), - user_data, - })) + Some(ProtocolsHandlerEvent::Custom( + KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(error), + user_data, + }, + )) } else { None }; @@ -848,110 +872,121 @@ fn advance_substream( } } } - SubstreamState::OutWaitingAnswer(mut substream, user_data) => match Stream::poll_next(Pin::new(&mut substream), cx) { - Poll::Ready(Some(Ok(msg))) => { - let new_state = SubstreamState::OutClosing(substream); - let event = process_kad_response(msg, user_data); - ( - Some(new_state), - Some(ProtocolsHandlerEvent::Custom(event)), - true, - ) - } - Poll::Pending => ( - Some(SubstreamState::OutWaitingAnswer(substream, user_data)), - None, - false, - ), - Poll::Ready(Some(Err(error))) => { - let event = KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(error), - user_data, - }; - (None, Some(ProtocolsHandlerEvent::Custom(event)), false) - } - Poll::Ready(None) => { - let event = KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(io::ErrorKind::UnexpectedEof.into()), - user_data, - }; - (None, Some(ProtocolsHandlerEvent::Custom(event)), false) + SubstreamState::OutWaitingAnswer(mut substream, user_data) => { + match Stream::poll_next(Pin::new(&mut substream), cx) { + Poll::Ready(Some(Ok(msg))) => { + let new_state = SubstreamState::OutClosing(substream); + let event = process_kad_response(msg, user_data); + ( + Some(new_state), + Some(ProtocolsHandlerEvent::Custom(event)), + true, + ) + } + Poll::Pending => ( + Some(SubstreamState::OutWaitingAnswer(substream, user_data)), + None, + false, + ), + Poll::Ready(Some(Err(error))) => { + let event = KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(error), + user_data, + }; + (None, Some(ProtocolsHandlerEvent::Custom(event)), false) + } + Poll::Ready(None) => { + let event = KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(io::ErrorKind::UnexpectedEof.into()), + user_data, + }; + (None, Some(ProtocolsHandlerEvent::Custom(event)), false) + } } - }, + } SubstreamState::OutReportError(error, user_data) => { let event = KademliaHandlerEvent::QueryError { error, user_data }; (None, Some(ProtocolsHandlerEvent::Custom(event)), false) } - SubstreamState::OutClosing(mut stream) => match Sink::poll_close(Pin::new(&mut stream), cx) { + SubstreamState::OutClosing(mut stream) => match Sink::poll_close(Pin::new(&mut stream), cx) + { Poll::Ready(Ok(())) => (None, None, false), Poll::Pending => (Some(SubstreamState::OutClosing(stream)), None, false), Poll::Ready(Err(_)) => (None, None, false), }, - SubstreamState::InWaitingMessage(id, mut substream) => match Stream::poll_next(Pin::new(&mut substream), cx) { - Poll::Ready(Some(Ok(msg))) => { - if let Ok(ev) = process_kad_request(msg, id) { - ( - Some(SubstreamState::InWaitingUser(id, substream)), - Some(ProtocolsHandlerEvent::Custom(ev)), - false, - ) - } else { - (Some(SubstreamState::InClosing(substream)), None, true) + SubstreamState::InWaitingMessage(id, mut substream) => { + match Stream::poll_next(Pin::new(&mut substream), cx) { + Poll::Ready(Some(Ok(msg))) => { + if let Ok(ev) = process_kad_request(msg, id) { + ( + Some(SubstreamState::InWaitingUser(id, substream)), + Some(ProtocolsHandlerEvent::Custom(ev)), + false, + ) + } else { + (Some(SubstreamState::InClosing(substream)), None, true) + } + } + Poll::Pending => ( + Some(SubstreamState::InWaitingMessage(id, substream)), + None, + false, + ), + Poll::Ready(None) => { + trace!("Inbound substream: EOF"); + (None, None, false) + } + Poll::Ready(Some(Err(e))) => { + trace!("Inbound substream error: {:?}", e); + (None, None, false) } } - Poll::Pending => ( - Some(SubstreamState::InWaitingMessage(id, substream)), - None, - false, - ), - Poll::Ready(None) => { - trace!("Inbound substream: EOF"); - (None, None, false) - } - Poll::Ready(Some(Err(e))) => { - trace!("Inbound substream error: {:?}", e); - (None, None, false) - }, - }, + } SubstreamState::InWaitingUser(id, substream) => ( Some(SubstreamState::InWaitingUser(id, substream)), None, false, ), - SubstreamState::InPendingSend(id, mut substream, msg) => match Sink::poll_ready(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => match Sink::start_send(Pin::new(&mut substream), msg) { - Ok(()) => ( - Some(SubstreamState::InPendingFlush(id, substream)), + SubstreamState::InPendingSend(id, mut substream, msg) => { + match Sink::poll_ready(Pin::new(&mut substream), cx) { + Poll::Ready(Ok(())) => match Sink::start_send(Pin::new(&mut substream), msg) { + Ok(()) => ( + Some(SubstreamState::InPendingFlush(id, substream)), + None, + true, + ), + Err(_) => (None, None, false), + }, + Poll::Pending => ( + Some(SubstreamState::InPendingSend(id, substream, msg)), + None, + false, + ), + Poll::Ready(Err(_)) => (None, None, false), + } + } + SubstreamState::InPendingFlush(id, mut substream) => { + match Sink::poll_flush(Pin::new(&mut substream), cx) { + Poll::Ready(Ok(())) => ( + Some(SubstreamState::InWaitingMessage(id, substream)), None, true, ), - Err(_) => (None, None, false), - }, - Poll::Pending => ( - Some(SubstreamState::InPendingSend(id, substream, msg)), - None, - false, - ), - Poll::Ready(Err(_)) => (None, None, false), + Poll::Pending => ( + Some(SubstreamState::InPendingFlush(id, substream)), + None, + false, + ), + Poll::Ready(Err(_)) => (None, None, false), + } + } + SubstreamState::InClosing(mut stream) => { + match Sink::poll_close(Pin::new(&mut stream), cx) { + Poll::Ready(Ok(())) => (None, None, false), + Poll::Pending => (Some(SubstreamState::InClosing(stream)), None, false), + Poll::Ready(Err(_)) => (None, None, false), + } } - SubstreamState::InPendingFlush(id, mut substream) => match Sink::poll_flush(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => ( - Some(SubstreamState::InWaitingMessage(id, substream)), - None, - true, - ), - Poll::Pending => ( - Some(SubstreamState::InPendingFlush(id, substream)), - None, - false, - ), - Poll::Ready(Err(_)) => (None, None, false), - }, - SubstreamState::InClosing(mut stream) => match Sink::poll_close(Pin::new(&mut stream), cx) { - Poll::Ready(Ok(())) => (None, None, false), - Poll::Pending => (Some(SubstreamState::InClosing(stream)), None, false), - Poll::Ready(Err(_)) => (None, None, false), - }, } } @@ -987,7 +1022,7 @@ fn process_kad_request( KadRequestMsg::PutValue { record } => Ok(KademliaHandlerEvent::PutRecord { record, request_id: KademliaRequestId { connec_unique_id }, - }) + }), } } @@ -1005,11 +1040,9 @@ fn process_kad_response( user_data, } } - KadResponseMsg::FindNode { closer_peers } => { - KademliaHandlerEvent::FindNodeRes { - closer_peers, - user_data, - } + KadResponseMsg::FindNode { closer_peers } => KademliaHandlerEvent::FindNodeRes { + closer_peers, + user_data, }, KadResponseMsg::GetProviders { closer_peers, @@ -1027,12 +1060,10 @@ fn process_kad_response( closer_peers, user_data, }, - KadResponseMsg::PutValue { key, value, .. } => { - KademliaHandlerEvent::PutRecordRes { - key, - value, - user_data, - } - } + KadResponseMsg::PutValue { key, value, .. } => KademliaHandlerEvent::PutRecordRes { + key, + value, + user_data, + }, } } diff --git a/protocols/kad/src/jobs.rs b/protocols/kad/src/jobs.rs index 8737f9ad8b9..402a797a52d 100644 --- a/protocols/kad/src/jobs.rs +++ b/protocols/kad/src/jobs.rs @@ -61,15 +61,15 @@ //! > to the size of all stored records. As a job runs, the records are moved //! > out of the job to the consumer, where they can be dropped after being sent. -use crate::record::{self, Record, ProviderRecord, store::RecordStore}; -use libp2p_core::PeerId; +use crate::record::{self, store::RecordStore, ProviderRecord, Record}; use futures::prelude::*; +use libp2p_core::PeerId; use std::collections::HashSet; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; use std::vec; -use wasm_timer::{Instant, Delay}; +use wasm_timer::{Delay, Instant}; /// The maximum number of queries towards which background jobs /// are allowed to start new queries on an invocation of @@ -110,7 +110,7 @@ impl PeriodicJob { fn is_ready(&mut self, cx: &mut Context<'_>, now: Instant) -> bool { if let PeriodicJobState::Waiting(delay, deadline) = &mut self.state { if now >= *deadline || !Future::poll(Pin::new(delay), cx).is_pending() { - return true + return true; } } false @@ -121,7 +121,7 @@ impl PeriodicJob { #[derive(Debug)] enum PeriodicJobState { Running(T), - Waiting(Delay, Instant) + Waiting(Delay, Instant), } ////////////////////////////////////////////////////////////////////////////// @@ -158,8 +158,8 @@ impl PutRecordJob { skipped: HashSet::new(), inner: PeriodicJob { interval: replicate_interval, - state: PeriodicJobState::Waiting(delay, deadline) - } + state: PeriodicJobState::Waiting(delay, deadline), + }, } } @@ -192,11 +192,12 @@ impl PutRecordJob { /// to be run. pub fn poll(&mut self, cx: &mut Context<'_>, store: &mut T, now: Instant) -> Poll where - for<'a> T: RecordStore<'a> + for<'a> T: RecordStore<'a>, { if self.inner.is_ready(cx, now) { let publish = self.next_publish.map_or(false, |t_pub| now >= t_pub); - let records = store.records() + let records = store + .records() .filter_map(|r| { let is_publisher = r.publisher.as_ref() == Some(&self.local_id); if self.skipped.contains(&r.key) || (!publish && is_publisher) { @@ -204,8 +205,9 @@ impl PutRecordJob { } else { let mut record = r.into_owned(); if publish && is_publisher { - record.expires = record.expires.or_else(|| - self.record_ttl.map(|ttl| now + ttl)); + record.expires = record + .expires + .or_else(|| self.record_ttl.map(|ttl| now + ttl)); } Some(record) } @@ -228,7 +230,7 @@ impl PutRecordJob { if r.is_expired(now) { store.remove(&r.key) } else { - return Poll::Ready(r) + return Poll::Ready(r); } } @@ -248,7 +250,7 @@ impl PutRecordJob { /// Periodic job for replicating provider records. pub struct AddProviderJob { - inner: PeriodicJob> + inner: PeriodicJob>, } impl AddProviderJob { @@ -261,8 +263,8 @@ impl AddProviderJob { state: { let deadline = now + interval; PeriodicJobState::Waiting(Delay::new_at(deadline), deadline) - } - } + }, + }, } } @@ -284,12 +286,18 @@ impl AddProviderJob { /// Must be called in the context of a task. When `NotReady` is returned, /// the current task is registered to be notified when the job is ready /// to be run. - pub fn poll(&mut self, cx: &mut Context<'_>, store: &mut T, now: Instant) -> Poll + pub fn poll( + &mut self, + cx: &mut Context<'_>, + store: &mut T, + now: Instant, + ) -> Poll where - for<'a> T: RecordStore<'a> + for<'a> T: RecordStore<'a>, { if self.inner.is_ready(cx, now) { - let records = store.provided() + let records = store + .provided() .map(|r| r.into_owned()) .collect::>() .into_iter(); @@ -301,7 +309,7 @@ impl AddProviderJob { if r.is_expired(now) { store.remove_provider(&r.key, &r.provider) } else { - return Poll::Ready(r) + return Poll::Ready(r); } } @@ -317,11 +325,11 @@ impl AddProviderJob { #[cfg(test)] mod tests { + use super::*; use crate::record::store::MemoryStore; use futures::{executor::block_on, future::poll_fn}; use quickcheck::*; use rand::Rng; - use super::*; fn rand_put_record_job() -> PutRecordJob { let mut rng = rand::thread_rng(); diff --git a/protocols/kad/src/kbucket.rs b/protocols/kad/src/kbucket.rs index ff00b0d7ed0..111407789d8 100644 --- a/protocols/kad/src/kbucket.rs +++ b/protocols/kad/src/kbucket.rs @@ -91,7 +91,7 @@ pub struct KBucketsTable { buckets: Vec>, /// The list of evicted entries that have been replaced with pending /// entries since the last call to [`KBucketsTable::take_applied_pending`]. - applied_pending: VecDeque> + applied_pending: VecDeque>, } /// A (type-safe) index into a `KBucketsTable`, i.e. a non-negative integer in the @@ -132,7 +132,7 @@ impl BucketIndex { fn rand_distance(&self, rng: &mut impl rand::Rng) -> Distance { let mut bytes = [0u8; 32]; let quot = self.0 / 8; - for i in 0 .. quot { + for i in 0..quot { bytes[31 - i] = rng.gen(); } let rem = (self.0 % 8) as u32; @@ -146,7 +146,7 @@ impl BucketIndex { impl KBucketsTable where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { /// Creates a new, empty Kademlia routing table with entries partitioned /// into buckets as per the Kademlia protocol. @@ -157,8 +157,10 @@ where pub fn new(local_key: TKey, pending_timeout: Duration) -> Self { KBucketsTable { local_key, - buckets: (0 .. NUM_BUCKETS).map(|_| KBucket::new(pending_timeout)).collect(), - applied_pending: VecDeque::new() + buckets: (0..NUM_BUCKETS) + .map(|_| KBucket::new(pending_timeout)) + .collect(), + applied_pending: VecDeque::new(), } } @@ -194,7 +196,7 @@ where } KBucketRef { index: BucketIndex(i), - bucket: b + bucket: b, } }) } @@ -236,10 +238,9 @@ where /// Returns an iterator over the keys closest to `target`, ordered by /// increasing distance. - pub fn closest_keys<'a, T>(&'a mut self, target: &'a T) - -> impl Iterator + 'a + pub fn closest_keys<'a, T>(&'a mut self, target: &'a T) -> impl Iterator + 'a where - T: Clone + AsRef + T: Clone + AsRef, { let distance = self.local_key.as_ref().distance(target); ClosestIter { @@ -248,18 +249,20 @@ where table: self, buckets_iter: ClosestBucketsIter::new(distance), fmap: |b: &KBucket| -> ArrayVec<_> { - b.iter().map(|(n,_)| n.key.clone()).collect() - } + b.iter().map(|(n, _)| n.key.clone()).collect() + }, } } /// Returns an iterator over the nodes closest to the `target` key, ordered by /// increasing distance. - pub fn closest<'a, T>(&'a mut self, target: &'a T) - -> impl Iterator> + 'a + pub fn closest<'a, T>( + &'a mut self, + target: &'a T, + ) -> impl Iterator> + 'a where T: Clone + AsRef, - TVal: Clone + TVal: Clone, { let distance = self.local_key.as_ref().distance(target); ClosestIter { @@ -268,11 +271,13 @@ where table: self, buckets_iter: ClosestBucketsIter::new(distance), fmap: |b: &KBucket<_, TVal>| -> ArrayVec<_> { - b.iter().map(|(n, status)| EntryView { - node: n.clone(), - status - }).collect() - } + b.iter() + .map(|(n, status)| EntryView { + node: n.clone(), + status, + }) + .collect() + }, } } @@ -283,14 +288,15 @@ where /// calculated by backtracking from the target towards the local key. pub fn count_nodes_between(&mut self, target: &T) -> usize where - T: AsRef + T: AsRef, { let local_key = self.local_key.clone(); let distance = target.as_ref().distance(&local_key); let mut iter = ClosestBucketsIter::new(distance).take_while(|i| i.get() != 0); if let Some(i) = iter.next() { - let num_first = self.buckets[i.get()].iter() - .filter(|(n,_)| n.key.as_ref().distance(&local_key) <= distance) + let num_first = self.buckets[i.get()] + .iter() + .filter(|(n, _)| n.key.as_ref().distance(&local_key) <= distance) .count(); let num_rest: usize = iter.map(|i| self.buckets[i.get()].num_entries()).sum(); num_first + num_rest @@ -317,7 +323,7 @@ struct ClosestIter<'a, TTarget, TKey, TVal, TMap, TOut> { iter: Option>, /// The projection function / mapping applied on each bucket as /// it is encountered, producing the next `iter`ator. - fmap: TMap + fmap: TMap, } /// An iterator over the bucket indices, in the order determined by the `Distance` of @@ -327,7 +333,7 @@ struct ClosestBucketsIter { /// The distance to the `local_key`. distance: Distance, /// The current state of the iterator. - state: ClosestBucketsIterState + state: ClosestBucketsIterState, } /// Operating states of a `ClosestBucketsIter`. @@ -348,34 +354,36 @@ enum ClosestBucketsIterState { /// `255` is reached, the iterator transitions to state `Done`. ZoomOut(BucketIndex), /// The iterator is in this state once it has visited all buckets. - Done + Done, } impl ClosestBucketsIter { fn new(distance: Distance) -> Self { let state = match BucketIndex::new(&distance) { Some(i) => ClosestBucketsIterState::Start(i), - None => ClosestBucketsIterState::Start(BucketIndex(0)) + None => ClosestBucketsIterState::Start(BucketIndex(0)), }; Self { distance, state } } fn next_in(&self, i: BucketIndex) -> Option { - (0 .. i.get()).rev().find_map(|i| + (0..i.get()).rev().find_map(|i| { if self.distance.0.bit(i) { Some(BucketIndex(i)) } else { None - }) + } + }) } fn next_out(&self, i: BucketIndex) -> Option { - (i.get() + 1 .. NUM_BUCKETS).find_map(|i| + (i.get() + 1..NUM_BUCKETS).find_map(|i| { if !self.distance.0.bit(i) { Some(BucketIndex(i)) } else { None - }) + } + }) } } @@ -388,7 +396,7 @@ impl Iterator for ClosestBucketsIter { self.state = ClosestBucketsIterState::ZoomIn(i); Some(i) } - ClosestBucketsIterState::ZoomIn(i) => + ClosestBucketsIterState::ZoomIn(i) => { if let Some(i) = self.next_in(i) { self.state = ClosestBucketsIterState::ZoomIn(i); Some(i) @@ -397,7 +405,8 @@ impl Iterator for ClosestBucketsIter { self.state = ClosestBucketsIterState::ZoomOut(i); Some(i) } - ClosestBucketsIterState::ZoomOut(i) => + } + ClosestBucketsIterState::ZoomOut(i) => { if let Some(i) = self.next_out(i) { self.state = ClosestBucketsIterState::ZoomOut(i); Some(i) @@ -405,19 +414,19 @@ impl Iterator for ClosestBucketsIter { self.state = ClosestBucketsIterState::Done; None } - ClosestBucketsIterState::Done => None + } + ClosestBucketsIterState::Done => None, } } } -impl Iterator -for ClosestIter<'_, TTarget, TKey, TVal, TMap, TOut> +impl Iterator for ClosestIter<'_, TTarget, TKey, TVal, TMap, TOut> where TTarget: AsRef, TKey: Clone + AsRef, TVal: Clone, TMap: Fn(&KBucket) -> ArrayVec<[TOut; K_VALUE.get()]>, - TOut: AsRef + TOut: AsRef, { type Item = TOut; @@ -426,8 +435,8 @@ where match &mut self.iter { Some(iter) => match iter.next() { Some(k) => return Some(k), - None => self.iter = None - } + None => self.iter = None, + }, None => { if let Some(i) = self.buckets_iter.next() { let bucket = &mut self.table.buckets[i.get()]; @@ -435,12 +444,15 @@ where self.table.applied_pending.push_back(applied) } let mut v = (self.fmap)(bucket); - v.sort_by(|a, b| - self.target.as_ref().distance(a.as_ref()) - .cmp(&self.target.as_ref().distance(b.as_ref()))); + v.sort_by(|a, b| { + self.target + .as_ref() + .distance(a.as_ref()) + .cmp(&self.target.as_ref().distance(b.as_ref())) + }); self.iter = Some(v.into_iter()); } else { - return None + return None; } } } @@ -451,13 +463,13 @@ where /// A reference to a bucket in a [`KBucketsTable`]. pub struct KBucketRef<'a, TKey, TVal> { index: BucketIndex, - bucket: &'a mut KBucket + bucket: &'a mut KBucket, } impl<'a, TKey, TVal> KBucketRef<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { /// Returns the minimum inclusive and maximum inclusive [`Distance`] for /// this bucket. @@ -497,14 +509,12 @@ where /// Returns an iterator over the entries in the bucket. pub fn iter(&'a self) -> impl Iterator> { - self.bucket.iter().map(move |(n, status)| { - EntryRefView { - node: NodeRefView { - key: &n.key, - value: &n.value - }, - status - } + self.bucket.iter().map(move |(n, status)| EntryRefView { + node: NodeRefView { + key: &n.key, + value: &n.value, + }, + status, }) } } @@ -528,14 +538,17 @@ mod tests { let ix = BucketIndex(i); let num = g.gen_range(0, usize::min(K_VALUE.get(), num_total) + 1); num_total -= num; - for _ in 0 .. num { + for _ in 0..num { let distance = ix.rand_distance(g); let key = local_key.for_distance(distance); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; let status = NodeStatus::arbitrary(g); match b.insert(node, status) { InsertResult::Inserted => {} - _ => panic!() + _ => panic!(), } } } @@ -607,7 +620,7 @@ mod tests { if let Entry::Absent(entry) = table.entry(&other_id) { match entry.insert((), NodeStatus::Connected) { InsertResult::Inserted => (), - _ => panic!() + _ => panic!(), } } else { panic!() @@ -634,7 +647,9 @@ mod tests { let mut table = KBucketsTable::<_, ()>::new(local_key, Duration::from_secs(5)); let mut count = 0; loop { - if count == 100 { break; } + if count == 100 { + break; + } let key = Key::from(PeerId::random()); if let Entry::Absent(e) = table.entry(&key) { match e.insert((), NodeStatus::Connected) { @@ -646,12 +661,13 @@ mod tests { } } - let mut expected_keys: Vec<_> = table.buckets + let mut expected_keys: Vec<_> = table + .buckets .iter() - .flat_map(|t| t.iter().map(|(n,_)| n.key.clone())) + .flat_map(|t| t.iter().map(|(n, _)| n.key.clone())) .collect(); - for _ in 0 .. 10 { + for _ in 0..10 { let target_key = Key::from(PeerId::random()); let keys = table.closest_keys(&target_key).collect::>(); // The list of keys is expected to match the result of a full-table scan. @@ -675,18 +691,24 @@ mod tests { match e.insert((), NodeStatus::Connected) { InsertResult::Pending { disconnected } => { expected_applied = AppliedPending { - inserted: Node { key: key.clone(), value: () }, - evicted: Some(Node { key: disconnected, value: () }) + inserted: Node { + key: key.clone(), + value: (), + }, + evicted: Some(Node { + key: disconnected, + value: (), + }), }; full_bucket_index = BucketIndex::new(&key.distance(&local_key)); - break - }, - _ => panic!() + break; + } + _ => panic!(), } } else { panic!() } - }, + } _ => continue, } } else { @@ -701,12 +723,12 @@ mod tests { match table.entry(&expected_applied.inserted.key) { Entry::Present(_, NodeStatus::Connected) => {} - x => panic!("Unexpected entry: {:?}", x) + x => panic!("Unexpected entry: {:?}", x), } match table.entry(&expected_applied.evicted.as_ref().unwrap().key) { Entry::Absent(_) => {} - x => panic!("Unexpected entry: {:?}", x) + x => panic!("Unexpected entry: {:?}", x), } assert_eq!(Some(expected_applied), table.take_applied_pending()); @@ -734,6 +756,8 @@ mod tests { }) } - QuickCheck::new().tests(10).quickcheck(prop as fn(_,_) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _) -> _) } } diff --git a/protocols/kad/src/kbucket/bucket.rs b/protocols/kad/src/kbucket/bucket.rs index e9729917e8f..b9d34519d5d 100644 --- a/protocols/kad/src/kbucket/bucket.rs +++ b/protocols/kad/src/kbucket/bucket.rs @@ -25,8 +25,8 @@ //! > buckets in a `KBucketsTable` and hence is enforced by the public API //! > of the `KBucketsTable` and in particular the public `Entry` API. -pub use crate::K_VALUE; use super::*; +pub use crate::K_VALUE; /// A `PendingNode` is a `Node` that is pending insertion into a `KBucket`. #[derive(Debug, Clone)] @@ -51,7 +51,7 @@ pub enum NodeStatus { /// The node is considered connected. Connected, /// The node is considered disconnected. - Disconnected + Disconnected, } impl PendingNode { @@ -125,29 +125,29 @@ pub struct KBucket { /// The timeout window before a new pending node is eligible for insertion, /// if the least-recently connected node is not updated as being connected /// in the meantime. - pending_timeout: Duration + pending_timeout: Duration, } /// The result of inserting an entry into a bucket. #[must_use] #[derive(Debug, Clone, PartialEq, Eq)] pub enum InsertResult { - /// The entry has been successfully inserted. - Inserted, - /// The entry is pending insertion because the relevant bucket is currently full. - /// The entry is inserted after a timeout elapsed, if the status of the - /// least-recently connected (and currently disconnected) node in the bucket - /// is not updated before the timeout expires. - Pending { - /// The key of the least-recently connected entry that is currently considered - /// disconnected and whose corresponding peer should be checked for connectivity - /// in order to prevent it from being evicted. If connectivity to the peer is - /// re-established, the corresponding entry should be updated with - /// [`NodeStatus::Connected`]. - disconnected: TKey - }, - /// The entry was not inserted because the relevant bucket is full. - Full + /// The entry has been successfully inserted. + Inserted, + /// The entry is pending insertion because the relevant bucket is currently full. + /// The entry is inserted after a timeout elapsed, if the status of the + /// least-recently connected (and currently disconnected) node in the bucket + /// is not updated before the timeout expires. + Pending { + /// The key of the least-recently connected entry that is currently considered + /// disconnected and whose corresponding peer should be checked for connectivity + /// in order to prevent it from being evicted. If connectivity to the peer is + /// re-established, the corresponding entry should be updated with + /// [`NodeStatus::Connected`]. + disconnected: TKey, + }, + /// The entry was not inserted because the relevant bucket is full. + Full, } /// The result of applying a pending node to a bucket, possibly @@ -158,13 +158,13 @@ pub struct AppliedPending { pub inserted: Node, /// The node that has been evicted from the bucket to make room for the /// pending node, if any. - pub evicted: Option> + pub evicted: Option>, } impl KBucket where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { /// Creates a new `KBucket` with the given timeout for pending entries. pub fn new(pending_timeout: Duration) -> Self { @@ -189,7 +189,8 @@ where /// Returns a reference to the pending node of the bucket, if there is any /// with a matching key. pub fn as_pending(&self, key: &TKey) -> Option<&PendingNode> { - self.pending().filter(|p| p.node.key.as_ref() == key.as_ref()) + self.pending() + .filter(|p| p.node.key.as_ref() == key.as_ref()) } /// Returns a reference to a node in the bucket. @@ -199,7 +200,10 @@ where /// Returns an iterator over the nodes in the bucket, together with their status. pub fn iter(&self) -> impl Iterator, NodeStatus)> { - self.nodes.iter().enumerate().map(move |(p, n)| (n, self.status(Position(p)))) + self.nodes + .iter() + .enumerate() + .map(move |(p, n)| (n, self.status(Position(p)))) } /// Inserts the pending node into the bucket, if its timeout has elapsed, @@ -214,21 +218,20 @@ where if self.nodes.is_full() { if self.status(Position(0)) == NodeStatus::Connected { // The bucket is full with connected nodes. Drop the pending node. - return None + return None; } debug_assert!(self.first_connected_pos.map_or(true, |p| p > 0)); // (*) - // The pending node will be inserted. + // The pending node will be inserted. let inserted = pending.node.clone(); // A connected pending node goes at the end of the list for // the connected peers, removing the least-recently connected. if pending.status == NodeStatus::Connected { let evicted = Some(self.nodes.remove(0)); - self.first_connected_pos = self.first_connected_pos - .map_or_else( - | | Some(self.nodes.len()), - |p| p.checked_sub(1)); + self.first_connected_pos = self + .first_connected_pos + .map_or_else(|| Some(self.nodes.len()), |p| p.checked_sub(1)); self.nodes.push(pending.node); - return Some(AppliedPending { inserted, evicted }) + return Some(AppliedPending { inserted, evicted }); } // A disconnected pending node goes at the end of the list // for the disconnected peers. @@ -236,21 +239,25 @@ where let insert_pos = p.checked_sub(1).expect("by (*)"); let evicted = Some(self.nodes.remove(0)); self.nodes.insert(insert_pos, pending.node); - return Some(AppliedPending { inserted, evicted }) + return Some(AppliedPending { inserted, evicted }); } else { // All nodes are disconnected. Insert the new node as the most // recently disconnected, removing the least-recently disconnected. let evicted = Some(self.nodes.remove(0)); self.nodes.push(pending.node); - return Some(AppliedPending { inserted, evicted }) + return Some(AppliedPending { inserted, evicted }); } } else { // There is room in the bucket, so just insert the pending node. let inserted = pending.node.clone(); match self.insert(pending.node, pending.status) { - InsertResult::Inserted => - return Some(AppliedPending { inserted, evicted: None }), - _ => unreachable!("Bucket is not full.") + InsertResult::Inserted => { + return Some(AppliedPending { + inserted, + evicted: None, + }) + } + _ => unreachable!("Bucket is not full."), } } } else { @@ -289,8 +296,8 @@ where } // Reinsert the node with the desired status. match self.insert(node, status) { - InsertResult::Inserted => {}, - _ => unreachable!("The node is removed before being (re)inserted.") + InsertResult::Inserted => {} + _ => unreachable!("The node is removed before being (re)inserted."), } } } @@ -317,7 +324,7 @@ where NodeStatus::Connected => { if self.nodes.is_full() { if self.first_connected_pos == Some(0) || self.pending.is_some() { - return InsertResult::Full + return InsertResult::Full; } else { self.pending = Some(PendingNode { node, @@ -325,8 +332,8 @@ where replace: Instant::now() + self.pending_timeout, }); return InsertResult::Pending { - disconnected: self.nodes[0].key.clone() - } + disconnected: self.nodes[0].key.clone(), + }; } } let pos = self.nodes.len(); @@ -336,7 +343,7 @@ where } NodeStatus::Disconnected => { if self.nodes.is_full() { - return InsertResult::Full + return InsertResult::Full; } if let Some(ref mut p) = self.first_connected_pos { self.nodes.insert(*p, node); @@ -357,17 +364,19 @@ where let node = self.nodes.remove(pos.0); // Adjust `first_connected_pos` accordingly. match status { - NodeStatus::Connected => + NodeStatus::Connected => { if self.first_connected_pos.map_or(false, |p| p == pos.0) { if pos.0 == self.nodes.len() { // It was the last connected node. self.first_connected_pos = None } } - NodeStatus::Disconnected => + } + NodeStatus::Disconnected => { if let Some(ref mut p) = self.first_connected_pos { *p -= 1; } + } } Some((node, status, pos)) } else { @@ -406,7 +415,10 @@ where /// Gets the position of an node in the bucket. pub fn position(&self, key: &TKey) -> Option { - self.nodes.iter().position(|p| p.key.as_ref() == key.as_ref()).map(Position) + self.nodes + .iter() + .position(|p| p.key.as_ref() == key.as_ref()) + .map(Position) } /// Gets a mutable reference to the node identified by the given key. @@ -414,30 +426,35 @@ where /// Returns `None` if the given key does not refer to a node in the /// bucket. pub fn get_mut(&mut self, key: &TKey) -> Option<&mut Node> { - self.nodes.iter_mut().find(move |p| p.key.as_ref() == key.as_ref()) + self.nodes + .iter_mut() + .find(move |p| p.key.as_ref() == key.as_ref()) } } #[cfg(test)] mod tests { + use super::*; use libp2p_core::PeerId; + use quickcheck::*; use rand::Rng; use std::collections::VecDeque; - use super::*; - use quickcheck::*; impl Arbitrary for KBucket, ()> { fn arbitrary(g: &mut G) -> KBucket, ()> { let timeout = Duration::from_secs(g.gen_range(1, g.size() as u64)); let mut bucket = KBucket::, ()>::new(timeout); let num_nodes = g.gen_range(1, K_VALUE.get() + 1); - for _ in 0 .. num_nodes { + for _ in 0..num_nodes { let key = Key::from(PeerId::random()); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; let status = NodeStatus::arbitrary(g); match bucket.insert(node, status) { InsertResult::Inserted => {} - _ => panic!() + _ => panic!(), } } bucket @@ -463,7 +480,7 @@ mod tests { // Fill a bucket with random nodes with the given status. fn fill_bucket(bucket: &mut KBucket, ()>, status: NodeStatus) { let num_entries_start = bucket.num_entries(); - for i in 0 .. K_VALUE.get() - num_entries_start { + for i in 0..K_VALUE.get() - num_entries_start { let key = Key::from(PeerId::random()); let node = Node { key, value: () }; assert_eq!(InsertResult::Inserted, bucket.insert(node, status)); @@ -483,13 +500,16 @@ mod tests { // Fill the bucket, thereby populating the expected lists in insertion order. for status in status { let key = Key::from(PeerId::random()); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; let full = bucket.num_entries() == K_VALUE.get(); match bucket.insert(node, status) { InsertResult::Inserted => { let vec = match status { NodeStatus::Connected => &mut connected, - NodeStatus::Disconnected => &mut disconnected + NodeStatus::Disconnected => &mut disconnected, }; if full { vec.pop_front(); @@ -501,21 +521,20 @@ mod tests { } // Get all nodes from the bucket, together with their status. - let mut nodes = bucket.iter() + let mut nodes = bucket + .iter() .map(|(n, s)| (s, n.key.clone())) .collect::>(); // Split the list of nodes at the first connected node. - let first_connected_pos = nodes.iter().position(|(s,_)| *s == NodeStatus::Connected); + let first_connected_pos = nodes.iter().position(|(s, _)| *s == NodeStatus::Connected); assert_eq!(bucket.first_connected_pos, first_connected_pos); let tail = first_connected_pos.map_or(Vec::new(), |p| nodes.split_off(p)); // All nodes before the first connected node must be disconnected and // in insertion order. Similarly, all remaining nodes must be connected // and in insertion order. - nodes == Vec::from(disconnected) - && - tail == Vec::from(connected) + nodes == Vec::from(disconnected) && tail == Vec::from(connected) } quickcheck(prop as fn(_) -> _); @@ -532,12 +551,12 @@ mod tests { let key = Key::from(PeerId::random()); let node = Node { key, value: () }; match bucket.insert(node, NodeStatus::Disconnected) { - InsertResult::Full => {}, - x => panic!("{:?}", x) + InsertResult::Full => {} + x => panic!("{:?}", x), } // One-by-one fill the bucket with connected nodes, replacing the disconnected ones. - for i in 0 .. K_VALUE.get() { + for i in 0..K_VALUE.get() { let (first, first_status) = bucket.iter().next().unwrap(); let first_disconnected = first.clone(); assert_eq!(first_status, NodeStatus::Disconnected); @@ -545,17 +564,21 @@ mod tests { // Add a connected node, which is expected to be pending, scheduled to // replace the first (i.e. least-recently connected) node. let key = Key::from(PeerId::random()); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; match bucket.insert(node.clone(), NodeStatus::Connected) { - InsertResult::Pending { disconnected } => - assert_eq!(disconnected, first_disconnected.key), - x => panic!("{:?}", x) + InsertResult::Pending { disconnected } => { + assert_eq!(disconnected, first_disconnected.key) + } + x => panic!("{:?}", x), } // Trying to insert another connected node fails. match bucket.insert(node.clone(), NodeStatus::Connected) { - InsertResult::Full => {}, - x => panic!("{:?}", x) + InsertResult::Full => {} + x => panic!("{:?}", x), } assert!(bucket.pending().is_some()); @@ -564,10 +587,13 @@ mod tests { let pending = bucket.pending_mut().expect("No pending node."); pending.set_ready_at(Instant::now() - Duration::from_secs(1)); let result = bucket.apply_pending(); - assert_eq!(result, Some(AppliedPending { - inserted: node.clone(), - evicted: Some(first_disconnected) - })); + assert_eq!( + result, + Some(AppliedPending { + inserted: node.clone(), + evicted: Some(first_disconnected) + }) + ); assert_eq!(Some((&node, NodeStatus::Connected)), bucket.iter().last()); assert!(bucket.pending().is_none()); assert_eq!(Some(K_VALUE.get() - (i + 1)), bucket.first_connected_pos); @@ -580,8 +606,8 @@ mod tests { let key = Key::from(PeerId::random()); let node = Node { key, value: () }; match bucket.insert(node, NodeStatus::Connected) { - InsertResult::Full => {}, - x => panic!("{:?}", x) + InsertResult::Full => {} + x => panic!("{:?}", x), } } @@ -594,7 +620,10 @@ mod tests { // Add a connected pending node. let key = Key::from(PeerId::random()); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; if let InsertResult::Pending { disconnected } = bucket.insert(node, NodeStatus::Connected) { assert_eq!(&disconnected, &first_disconnected.key); } else { @@ -607,16 +636,21 @@ mod tests { // The pending node has been discarded. assert!(bucket.pending().is_none()); - assert!(bucket.iter().all(|(n,_)| &n.key != &key)); + assert!(bucket.iter().all(|(n, _)| &n.key != &key)); // The initially disconnected node is now the most-recently connected. - assert_eq!(Some((&first_disconnected, NodeStatus::Connected)), bucket.iter().last()); - assert_eq!(bucket.position(&first_disconnected.key).map(|p| p.0), bucket.first_connected_pos); + assert_eq!( + Some((&first_disconnected, NodeStatus::Connected)), + bucket.iter().last() + ); + assert_eq!( + bucket.position(&first_disconnected.key).map(|p| p.0), + bucket.first_connected_pos + ); assert_eq!(1, bucket.num_connected()); assert_eq!(K_VALUE.get() - 1, bucket.num_disconnected()); } - #[test] fn bucket_update() { fn prop(mut bucket: KBucket, ()>, pos: Position, status: NodeStatus) -> bool { @@ -627,7 +661,10 @@ mod tests { let key = bucket.nodes[pos].key.clone(); // Record the (ordered) list of status of all nodes in the bucket. - let mut expected = bucket.iter().map(|(n,s)| (n.key.clone(), s)).collect::>(); + let mut expected = bucket + .iter() + .map(|(n, s)| (n.key.clone(), s)) + .collect::>(); // Update the node in the bucket. bucket.update(&key, status); @@ -636,14 +673,17 @@ mod tests { // preserving the status and relative order of all other nodes. let expected_pos = match status { NodeStatus::Connected => num_nodes - 1, - NodeStatus::Disconnected => bucket.first_connected_pos.unwrap_or(num_nodes) - 1 + NodeStatus::Disconnected => bucket.first_connected_pos.unwrap_or(num_nodes) - 1, }; expected.remove(pos); expected.insert(expected_pos, (key.clone(), status)); - let actual = bucket.iter().map(|(n,s)| (n.key.clone(), s)).collect::>(); + let actual = bucket + .iter() + .map(|(n, s)| (n.key.clone(), s)) + .collect::>(); expected == actual } - quickcheck(prop as fn(_,_,_) -> _); + quickcheck(prop as fn(_, _, _) -> _); } } diff --git a/protocols/kad/src/kbucket/entry.rs b/protocols/kad/src/kbucket/entry.rs index e72140cec73..3447146007b 100644 --- a/protocols/kad/src/kbucket/entry.rs +++ b/protocols/kad/src/kbucket/entry.rs @@ -21,7 +21,7 @@ //! The `Entry` API for quering and modifying the entries of a `KBucketsTable` //! representing the nodes participating in the Kademlia DHT. -pub use super::bucket::{Node, NodeStatus, InsertResult, AppliedPending, K_VALUE}; +pub use super::bucket::{AppliedPending, InsertResult, Node, NodeStatus, K_VALUE}; pub use super::key::*; use super::*; @@ -31,27 +31,27 @@ pub struct EntryRefView<'a, TPeerId, TVal> { /// The node represented by the entry. pub node: NodeRefView<'a, TPeerId, TVal>, /// The status of the node identified by the key. - pub status: NodeStatus + pub status: NodeStatus, } /// An immutable by-reference view of a `Node`. pub struct NodeRefView<'a, TKey, TVal> { pub key: &'a TKey, - pub value: &'a TVal + pub value: &'a TVal, } impl EntryRefView<'_, TKey, TVal> { pub fn to_owned(&self) -> EntryView where TKey: Clone, - TVal: Clone + TVal: Clone, { EntryView { node: Node { key: self.node.key.clone(), - value: self.node.value.clone() + value: self.node.value.clone(), }, - status: self.status + status: self.status, } } } @@ -63,7 +63,7 @@ pub struct EntryView { /// The node represented by the entry. pub node: Node, /// The status of the node. - pub status: NodeStatus + pub status: NodeStatus, } impl, TVal> AsRef for EntryView { @@ -96,7 +96,7 @@ struct EntryRef<'a, TKey, TVal> { impl<'a, TKey, TVal> Entry<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { /// Creates a new `Entry` for a `Key`, encapsulating access to a bucket. pub(super) fn new(bucket: &'a mut KBucket, key: &'a TKey) -> Self { @@ -120,18 +120,18 @@ where Entry::Present(entry, status) => Some(EntryRefView { node: NodeRefView { key: entry.0.key, - value: entry.value() + value: entry.value(), }, - status: *status + status: *status, }), Entry::Pending(entry, status) => Some(EntryRefView { node: NodeRefView { key: entry.0.key, - value: entry.value() + value: entry.value(), }, - status: *status + status: *status, }), - _ => None + _ => None, } } @@ -170,7 +170,7 @@ pub struct PresentEntry<'a, TKey, TVal>(EntryRef<'a, TKey, TVal>); impl<'a, TKey, TVal> PresentEntry<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { fn new(bucket: &'a mut KBucket, key: &'a TKey) -> Self { PresentEntry(EntryRef { bucket, key }) @@ -183,7 +183,9 @@ where /// Returns the value associated with the key. pub fn value(&mut self) -> &mut TVal { - &mut self.0.bucket + &mut self + .0 + .bucket .get_mut(self.0.key) .expect("We can only build a PresentEntry if the entry is in the bucket; QED") .value @@ -196,7 +198,9 @@ where /// Removes the entry from the bucket. pub fn remove(self) -> EntryView { - let (node, status, _pos) = self.0.bucket + let (node, status, _pos) = self + .0 + .bucket .remove(&self.0.key) .expect("We can only build a PresentEntry if the entry is in the bucket; QED"); EntryView { node, status } @@ -210,7 +214,7 @@ pub struct PendingEntry<'a, TKey, TVal>(EntryRef<'a, TKey, TVal>); impl<'a, TKey, TVal> PendingEntry<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { fn new(bucket: &'a mut KBucket, key: &'a TKey) -> Self { PendingEntry(EntryRef { bucket, key }) @@ -223,7 +227,8 @@ where /// Returns the value associated with the key. pub fn value(&mut self) -> &mut TVal { - self.0.bucket + self.0 + .bucket .pending_mut() .expect("We can only build a ConnectedPendingEntry if the entry is pending; QED") .value_mut() @@ -237,10 +242,10 @@ where /// Removes the pending entry from the bucket. pub fn remove(self) -> EntryView { - let pending = self.0.bucket - .remove_pending() - .expect("We can only build a PendingEntry if the entry is pending insertion - into the bucket; QED"); + let pending = self.0.bucket.remove_pending().expect( + "We can only build a PendingEntry if the entry is pending insertion + into the bucket; QED", + ); let status = pending.status(); let node = pending.into_node(); EntryView { node, status } @@ -254,7 +259,7 @@ pub struct AbsentEntry<'a, TKey, TVal>(EntryRef<'a, TKey, TVal>); impl<'a, TKey, TVal> AbsentEntry<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { fn new(bucket: &'a mut KBucket, key: &'a TKey) -> Self { AbsentEntry(EntryRef { bucket, key }) @@ -267,9 +272,12 @@ where /// Attempts to insert the entry into a bucket. pub fn insert(self, value: TVal, status: NodeStatus) -> InsertResult { - self.0.bucket.insert(Node { - key: self.0.key.clone(), - value - }, status) + self.0.bucket.insert( + Node { + key: self.0.key.clone(), + value, + }, + status, + ) } } diff --git a/protocols/kad/src/kbucket/key.rs b/protocols/kad/src/kbucket/key.rs index 38eb825ae66..ca3444da636 100644 --- a/protocols/kad/src/kbucket/key.rs +++ b/protocols/kad/src/kbucket/key.rs @@ -18,13 +18,13 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use uint::*; -use libp2p_core::{PeerId, multihash::Multihash}; +use crate::record; +use libp2p_core::{multihash::Multihash, PeerId}; +use sha2::digest::generic_array::{typenum::U32, GenericArray}; use sha2::{Digest, Sha256}; -use sha2::digest::generic_array::{GenericArray, typenum::U32}; use std::borrow::Borrow; use std::hash::{Hash, Hasher}; -use crate::record; +use uint::*; construct_uint! { /// 256-bit unsigned integer. @@ -52,7 +52,7 @@ impl Key { /// [`Key::into_preimage`]. pub fn new(preimage: T) -> Key where - T: Borrow<[u8]> + T: Borrow<[u8]>, { let bytes = KeyBytes::new(preimage.borrow()); Key { preimage, bytes } @@ -71,7 +71,7 @@ impl Key { /// Computes the distance of the keys according to the XOR metric. pub fn distance(&self, other: &U) -> Distance where - U: AsRef + U: AsRef, { self.bytes.distance(other) } @@ -93,22 +93,16 @@ impl Into for Key { } impl From for Key { - fn from(m: Multihash) -> Self { - let bytes = KeyBytes(Sha256::digest(&m.to_bytes())); - Key { - preimage: m, - bytes - } - } + fn from(m: Multihash) -> Self { + let bytes = KeyBytes(Sha256::digest(&m.to_bytes())); + Key { preimage: m, bytes } + } } impl From for Key { fn from(p: PeerId) -> Self { - let bytes = KeyBytes(Sha256::digest(&p.to_bytes())); - Key { - preimage: p, - bytes - } + let bytes = KeyBytes(Sha256::digest(&p.to_bytes())); + Key { preimage: p, bytes } } } @@ -153,7 +147,7 @@ impl KeyBytes { /// value through a random oracle. pub fn new(value: T) -> Self where - T: Borrow<[u8]> + T: Borrow<[u8]>, { KeyBytes(Sha256::digest(value.borrow())) } @@ -161,7 +155,7 @@ impl KeyBytes { /// Computes the distance of the keys according to the XOR metric. pub fn distance(&self, other: &U) -> Distance where - U: AsRef + U: AsRef, { let a = U256::from(self.0.as_slice()); let b = U256::from(other.as_ref().0.as_slice()); @@ -201,8 +195,8 @@ impl Distance { #[cfg(test)] mod tests { use super::*; - use quickcheck::*; use libp2p_core::multihash::Code; + use quickcheck::*; use rand::Rng; impl Arbitrary for Key { @@ -231,7 +225,7 @@ mod tests { fn prop(a: Key, b: Key) -> bool { a.distance(&b) == b.distance(&a) } - quickcheck(prop as fn(_,_) -> _) + quickcheck(prop as fn(_, _) -> _) } #[test] @@ -246,18 +240,18 @@ mod tests { TestResult::from_bool(a.distance(&c) <= Distance(ab_plus_bc)) } } - quickcheck(prop as fn(_,_,_) -> _) + quickcheck(prop as fn(_, _, _) -> _) } #[test] fn unidirectionality() { fn prop(a: Key, b: Key) -> bool { let d = a.distance(&b); - (0 .. 100).all(|_| { + (0..100).all(|_| { let c = Key::from(PeerId::random()); a.distance(&c) != d || b == c }) } - quickcheck(prop as fn(_,_) -> _) + quickcheck(prop as fn(_, _) -> _) } } diff --git a/protocols/kad/src/lib.rs b/protocols/kad/src/lib.rs index 30819ec0056..0fbeb61587d 100644 --- a/protocols/kad/src/lib.rs +++ b/protocols/kad/src/lib.rs @@ -40,56 +40,19 @@ mod dht_proto { pub use addresses::Addresses; pub use behaviour::{ - Kademlia, - KademliaBucketInserts, - KademliaConfig, - KademliaCaching, - KademliaEvent, - Quorum + AddProviderContext, AddProviderError, AddProviderOk, AddProviderPhase, AddProviderResult, + BootstrapError, BootstrapOk, BootstrapResult, GetClosestPeersError, GetClosestPeersOk, + GetClosestPeersResult, GetProvidersError, GetProvidersOk, GetProvidersResult, GetRecordError, + GetRecordOk, GetRecordResult, InboundRequest, PeerRecord, PutRecordContext, PutRecordError, + PutRecordOk, PutRecordPhase, PutRecordResult, QueryInfo, QueryMut, QueryRef, QueryResult, + QueryStats, }; pub use behaviour::{ - InboundRequest, - - QueryRef, - QueryMut, - - QueryResult, - QueryInfo, - QueryStats, - - PeerRecord, - - BootstrapResult, - BootstrapOk, - BootstrapError, - - GetRecordResult, - GetRecordOk, - GetRecordError, - - PutRecordPhase, - PutRecordContext, - PutRecordResult, - PutRecordOk, - PutRecordError, - - GetClosestPeersResult, - GetClosestPeersOk, - GetClosestPeersError, - - AddProviderPhase, - AddProviderContext, - AddProviderResult, - AddProviderOk, - AddProviderError, - - GetProvidersResult, - GetProvidersOk, - GetProvidersError, + Kademlia, KademliaBucketInserts, KademliaCaching, KademliaConfig, KademliaEvent, Quorum, }; -pub use query::QueryId; pub use protocol::KadConnectionType; -pub use record::{store, Record, ProviderRecord}; +pub use query::QueryId; +pub use record::{store, ProviderRecord, Record}; use std::num::NonZeroUsize; diff --git a/protocols/kad/src/protocol.rs b/protocols/kad/src/protocol.rs index 393074e2932..0f883649b05 100644 --- a/protocols/kad/src/protocol.rs +++ b/protocols/kad/src/protocol.rs @@ -26,14 +26,14 @@ //! to poll the underlying transport for incoming messages, and the `Sink` component //! is used to send messages to remote peers. -use bytes::BytesMut; -use codec::UviBytes; use crate::dht_proto as proto; use crate::record::{self, Record}; -use futures::prelude::*; use asynchronous_codec::Framed; -use libp2p_core::{Multiaddr, PeerId}; +use bytes::BytesMut; +use codec::UviBytes; +use futures::prelude::*; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; +use libp2p_core::{Multiaddr, PeerId}; use prost::Message; use std::{borrow::Cow, convert::TryFrom, time::Duration}; use std::{io, iter}; @@ -101,8 +101,7 @@ impl TryFrom for KadPeer { fn try_from(peer: proto::message::Peer) -> Result { // TODO: this is in fact a CID; not sure if this should be handled in `from_bytes` or // as a special case here - let node_id = PeerId::from_bytes(&peer.id) - .map_err(|_| invalid_data("invalid peer id"))?; + let node_id = PeerId::from_bytes(&peer.id).map_err(|_| invalid_data("invalid peer id"))?; let mut addrs = Vec::with_capacity(peer.addrs.len()); for addr in peer.addrs.into_iter() { @@ -118,7 +117,7 @@ impl TryFrom for KadPeer { Ok(KadPeer { node_id, multiaddrs: addrs, - connection_ty + connection_ty, }) } } @@ -131,7 +130,7 @@ impl From for proto::message::Peer { connection: { let ct: proto::message::ConnectionType = peer.connection_ty.into(); ct as i32 - } + }, } } } @@ -202,13 +201,15 @@ where .with::<_, _, fn(_) -> _, _>(|response| { let proto_struct = resp_msg_to_proto(response); let mut buf = Vec::with_capacity(proto_struct.encoded_len()); - proto_struct.encode(&mut buf).expect("Vec provides capacity as needed"); + proto_struct + .encode(&mut buf) + .expect("Vec provides capacity as needed"); future::ready(Ok(io::Cursor::new(buf))) }) .and_then::<_, fn(_) -> _>(|bytes| { let request = match proto::Message::decode(bytes) { Ok(r) => r, - Err(err) => return future::ready(Err(err.into())) + Err(err) => return future::ready(Err(err.into())), }; future::ready(proto_to_req_msg(request)) }), @@ -234,13 +235,15 @@ where .with::<_, _, fn(_) -> _, _>(|request| { let proto_struct = req_msg_to_proto(request); let mut buf = Vec::with_capacity(proto_struct.encoded_len()); - proto_struct.encode(&mut buf).expect("Vec provides capacity as needed"); + proto_struct + .encode(&mut buf) + .expect("Vec provides capacity as needed"); future::ready(Ok(io::Cursor::new(buf))) }) .and_then::<_, fn(_) -> _>(|bytes| { let response = match proto::Message::decode(bytes) { Ok(r) => r, - Err(err) => return future::ready(Err(err.into())) + Err(err) => return future::ready(Err(err.into())), }; future::ready(proto_to_resp_msg(response)) }), @@ -301,9 +304,7 @@ pub enum KadRequestMsg { }, /// Request to put a value into the dht records. - PutValue { - record: Record, - } + PutValue { record: Record }, } /// Response that we can send to a peer or that we received from a peer. @@ -348,38 +349,38 @@ fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message { match kad_msg { KadRequestMsg::Ping => proto::Message { r#type: proto::message::MessageType::Ping as i32, - .. proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::FindNode { key } => proto::Message { r#type: proto::message::MessageType::FindNode as i32, key, cluster_level_raw: 10, - .. proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::GetProviders { key } => proto::Message { r#type: proto::message::MessageType::GetProviders as i32, key: key.to_vec(), cluster_level_raw: 10, - .. proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::AddProvider { key, provider } => proto::Message { r#type: proto::message::MessageType::AddProvider as i32, cluster_level_raw: 10, key: key.to_vec(), provider_peers: vec![provider.into()], - .. proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::GetValue { key } => proto::Message { r#type: proto::message::MessageType::GetValue as i32, cluster_level_raw: 10, key: key.to_vec(), - .. proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::PutValue { record } => proto::Message { r#type: proto::message::MessageType::PutValue as i32, record: Some(record_to_proto(record)), - .. proto::Message::default() - } + ..proto::Message::default() + }, } } @@ -388,27 +389,33 @@ fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message { match kad_msg { KadResponseMsg::Pong => proto::Message { r#type: proto::message::MessageType::Ping as i32, - .. proto::Message::default() + ..proto::Message::default() }, KadResponseMsg::FindNode { closer_peers } => proto::Message { r#type: proto::message::MessageType::FindNode as i32, cluster_level_raw: 9, closer_peers: closer_peers.into_iter().map(KadPeer::into).collect(), - .. proto::Message::default() + ..proto::Message::default() }, - KadResponseMsg::GetProviders { closer_peers, provider_peers } => proto::Message { + KadResponseMsg::GetProviders { + closer_peers, + provider_peers, + } => proto::Message { r#type: proto::message::MessageType::GetProviders as i32, cluster_level_raw: 9, closer_peers: closer_peers.into_iter().map(KadPeer::into).collect(), provider_peers: provider_peers.into_iter().map(KadPeer::into).collect(), - .. proto::Message::default() + ..proto::Message::default() }, - KadResponseMsg::GetValue { record, closer_peers } => proto::Message { + KadResponseMsg::GetValue { + record, + closer_peers, + } => proto::Message { r#type: proto::message::MessageType::GetValue as i32, cluster_level_raw: 9, closer_peers: closer_peers.into_iter().map(KadPeer::into).collect(), record: record.map(record_to_proto), - .. proto::Message::default() + ..proto::Message::default() }, KadResponseMsg::PutValue { key, value } => proto::Message { r#type: proto::message::MessageType::PutValue as i32, @@ -416,10 +423,10 @@ fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message { record: Some(proto::Record { key: key.to_vec(), value, - .. proto::Record::default() + ..proto::Record::default() }), - .. proto::Message::default() - } + ..proto::Message::default() + }, } } @@ -436,20 +443,19 @@ fn proto_to_req_msg(message: proto::Message) -> Result let record = record_from_proto(message.record.unwrap_or_default())?; Ok(KadRequestMsg::PutValue { record }) } - proto::message::MessageType::GetValue => { - Ok(KadRequestMsg::GetValue { key: record::Key::from(message.key) }) - } - proto::message::MessageType::FindNode => { - Ok(KadRequestMsg::FindNode { key: message.key }) - } - proto::message::MessageType::GetProviders => { - Ok(KadRequestMsg::GetProviders { key: record::Key::from(message.key)}) - } + proto::message::MessageType::GetValue => Ok(KadRequestMsg::GetValue { + key: record::Key::from(message.key), + }), + proto::message::MessageType::FindNode => Ok(KadRequestMsg::FindNode { key: message.key }), + proto::message::MessageType::GetProviders => Ok(KadRequestMsg::GetProviders { + key: record::Key::from(message.key), + }), proto::message::MessageType::AddProvider => { // TODO: for now we don't parse the peer properly, so it is possible that we get // parsing errors for peers even when they are valid; we ignore these // errors for now, but ultimately we should just error altogether - let provider = message.provider_peers + let provider = message + .provider_peers .into_iter() .find_map(|peer| KadPeer::try_from(peer).ok()); @@ -473,22 +479,28 @@ fn proto_to_resp_msg(message: proto::Message) -> Result Ok(KadResponseMsg::Pong), proto::message::MessageType::GetValue => { - let record = - if let Some(r) = message.record { - Some(record_from_proto(r)?) - } else { - None - }; + let record = if let Some(r) = message.record { + Some(record_from_proto(r)?) + } else { + None + }; - let closer_peers = message.closer_peers.into_iter() + let closer_peers = message + .closer_peers + .into_iter() .filter_map(|peer| KadPeer::try_from(peer).ok()) .collect(); - Ok(KadResponseMsg::GetValue { record, closer_peers }) + Ok(KadResponseMsg::GetValue { + record, + closer_peers, + }) } proto::message::MessageType::FindNode => { - let closer_peers = message.closer_peers.into_iter() + let closer_peers = message + .closer_peers + .into_iter() .filter_map(|peer| KadPeer::try_from(peer).ok()) .collect(); @@ -496,11 +508,15 @@ fn proto_to_resp_msg(message: proto::Message) -> Result { - let closer_peers = message.closer_peers.into_iter() + let closer_peers = message + .closer_peers + .into_iter() .filter_map(|peer| KadPeer::try_from(peer).ok()) .collect(); - let provider_peers = message.provider_peers.into_iter() + let provider_peers = message + .provider_peers + .into_iter() .filter_map(|peer| KadPeer::try_from(peer).ok()) .collect(); @@ -512,18 +528,19 @@ fn proto_to_resp_msg(message: proto::Message) -> Result { let key = record::Key::from(message.key); - let rec = message.record.ok_or_else(|| { - invalid_data("received PutValue message with no record") - })?; + let rec = message + .record + .ok_or_else(|| invalid_data("received PutValue message with no record"))?; Ok(KadResponseMsg::PutValue { key, - value: rec.value + value: rec.value, }) } - proto::message::MessageType::AddProvider => + proto::message::MessageType::AddProvider => { Err(invalid_data("received an unexpected AddProvider message")) + } } } @@ -531,23 +548,26 @@ fn record_from_proto(record: proto::Record) -> Result { let key = record::Key::from(record.key); let value = record.value; - let publisher = - if !record.publisher.is_empty() { - PeerId::from_bytes(&record.publisher) - .map(Some) - .map_err(|_| invalid_data("Invalid publisher peer ID."))? - } else { - None - }; - - let expires = - if record.ttl > 0 { - Some(Instant::now() + Duration::from_secs(record.ttl as u64)) - } else { - None - }; - - Ok(Record { key, value, publisher, expires }) + let publisher = if !record.publisher.is_empty() { + PeerId::from_bytes(&record.publisher) + .map(Some) + .map_err(|_| invalid_data("Invalid publisher peer ID."))? + } else { + None + }; + + let expires = if record.ttl > 0 { + Some(Instant::now() + Duration::from_secs(record.ttl as u64)) + } else { + None + }; + + Ok(Record { + key, + value, + publisher, + expires, + }) } fn record_to_proto(record: Record) -> proto::Record { @@ -555,7 +575,8 @@ fn record_to_proto(record: Record) -> proto::Record { key: record.key.to_vec(), value: record.value, publisher: record.publisher.map(|id| id.to_bytes()).unwrap_or_default(), - ttl: record.expires + ttl: record + .expires .map(|t| { let now = Instant::now(); if t > now { @@ -565,14 +586,14 @@ fn record_to_proto(record: Record) -> proto::Record { } }) .unwrap_or(0), - time_received: String::new() + time_received: String::new(), } } /// Creates an `io::Error` with `io::ErrorKind::InvalidData`. fn invalid_data(e: E) -> io::Error where - E: Into> + E: Into>, { io::Error::new(io::ErrorKind::InvalidData, e) } diff --git a/protocols/kad/src/query.rs b/protocols/kad/src/query.rs index 0b19425b7fe..6fcf90df79f 100644 --- a/protocols/kad/src/query.rs +++ b/protocols/kad/src/query.rs @@ -20,16 +20,18 @@ mod peers; -use peers::PeersIterState; -use peers::closest::{ClosestPeersIterConfig, ClosestPeersIter, disjoint::ClosestDisjointPeersIter}; +use peers::closest::{ + disjoint::ClosestDisjointPeersIter, ClosestPeersIter, ClosestPeersIterConfig, +}; use peers::fixed::FixedPeersIter; +use peers::PeersIterState; -use crate::{ALPHA_VALUE, K_VALUE}; use crate::kbucket::{Key, KeyBytes}; +use crate::{ALPHA_VALUE, K_VALUE}; use either::Either; use fnv::FnvHashMap; use libp2p_core::PeerId; -use std::{time::Duration, num::NonZeroUsize}; +use std::{num::NonZeroUsize, time::Duration}; use wasm_timer::Instant; /// A `QueryPool` provides an aggregate state machine for driving `Query`s to completion. @@ -53,7 +55,7 @@ pub enum QueryPoolState<'a, TInner> { /// A query has finished. Finished(Query), /// A query has timed out. - Timeout(Query) + Timeout(Query), } impl QueryPool { @@ -62,7 +64,7 @@ impl QueryPool { QueryPool { next_id: 0, config, - queries: Default::default() + queries: Default::default(), } } @@ -89,7 +91,7 @@ impl QueryPool { /// Adds a query to the pool that contacts a fixed set of peers. pub fn add_fixed(&mut self, peers: I, inner: TInner) -> QueryId where - I: IntoIterator + I: IntoIterator, { let id = self.next_query_id(); self.continue_fixed(id, peers, inner); @@ -101,7 +103,7 @@ impl QueryPool { /// earlier. pub fn continue_fixed(&mut self, id: QueryId, peers: I, inner: TInner) where - I: IntoIterator + I: IntoIterator, { assert!(!self.queries.contains_key(&id)); let parallelism = self.config.replication_factor; @@ -114,7 +116,7 @@ impl QueryPool { pub fn add_iter_closest(&mut self, target: T, peers: I, inner: TInner) -> QueryId where T: Into + Clone, - I: IntoIterator> + I: IntoIterator>, { let id = self.next_query_id(); self.continue_iter_closest(id, target, peers, inner); @@ -125,18 +127,18 @@ impl QueryPool { pub fn continue_iter_closest(&mut self, id: QueryId, target: T, peers: I, inner: TInner) where T: Into + Clone, - I: IntoIterator> + I: IntoIterator>, { let cfg = ClosestPeersIterConfig { num_results: self.config.replication_factor, parallelism: self.config.parallelism, - .. ClosestPeersIterConfig::default() + ..ClosestPeersIterConfig::default() }; let peer_iter = if self.config.disjoint_query_paths { - QueryPeerIter::ClosestDisjoint( - ClosestDisjointPeersIter::with_config(cfg, target, peers), - ) + QueryPeerIter::ClosestDisjoint(ClosestDisjointPeersIter::with_config( + cfg, target, peers, + )) } else { QueryPeerIter::Closest(ClosestPeersIter::with_config(cfg, target, peers)) }; @@ -172,18 +174,18 @@ impl QueryPool { match query.next(now) { PeersIterState::Finished => { finished = Some(query_id); - break + break; } PeersIterState::Waiting(Some(peer_id)) => { let peer = peer_id.into_owned(); waiting = Some((query_id, peer)); - break + break; } PeersIterState::Waiting(None) | PeersIterState::WaitingAtCapacity => { let elapsed = now - query.stats.start.unwrap_or(now); if elapsed >= self.config.timeout { timeout = Some(query_id); - break + break; } } } @@ -191,19 +193,19 @@ impl QueryPool { if let Some((query_id, peer_id)) = waiting { let query = self.queries.get_mut(&query_id).expect("s.a."); - return QueryPoolState::Waiting(Some((query, peer_id))) + return QueryPoolState::Waiting(Some((query, peer_id))); } if let Some(query_id) = finished { let mut query = self.queries.remove(&query_id).expect("s.a."); query.stats.end = Some(now); - return QueryPoolState::Finished(query) + return QueryPoolState::Finished(query); } if let Some(query_id) = timeout { let mut query = self.queries.remove(&query_id).expect("s.a."); query.stats.end = Some(now); - return QueryPoolState::Timeout(query) + return QueryPoolState::Timeout(query); } if self.queries.is_empty() { @@ -269,13 +271,18 @@ pub struct Query { enum QueryPeerIter { Closest(ClosestPeersIter), ClosestDisjoint(ClosestDisjointPeersIter), - Fixed(FixedPeersIter) + Fixed(FixedPeersIter), } impl Query { /// Creates a new query without starting it. fn new(id: QueryId, peer_iter: QueryPeerIter, inner: TInner) -> Self { - Query { id, inner, peer_iter, stats: QueryStats::empty() } + Query { + id, + inner, + peer_iter, + stats: QueryStats::empty(), + } } /// Gets the unique ID of the query. @@ -293,7 +300,7 @@ impl Query { let updated = match &mut self.peer_iter { QueryPeerIter::Closest(iter) => iter.on_failure(peer), QueryPeerIter::ClosestDisjoint(iter) => iter.on_failure(peer), - QueryPeerIter::Fixed(iter) => iter.on_failure(peer) + QueryPeerIter::Fixed(iter) => iter.on_failure(peer), }; if updated { self.stats.failure += 1; @@ -305,12 +312,12 @@ impl Query { /// the query, if applicable. pub fn on_success(&mut self, peer: &PeerId, new_peers: I) where - I: IntoIterator + I: IntoIterator, { let updated = match &mut self.peer_iter { QueryPeerIter::Closest(iter) => iter.on_success(peer, new_peers), QueryPeerIter::ClosestDisjoint(iter) => iter.on_success(peer, new_peers), - QueryPeerIter::Fixed(iter) => iter.on_success(peer) + QueryPeerIter::Fixed(iter) => iter.on_success(peer), }; if updated { self.stats.success += 1; @@ -322,7 +329,7 @@ impl Query { match &self.peer_iter { QueryPeerIter::Closest(iter) => iter.is_waiting(peer), QueryPeerIter::ClosestDisjoint(iter) => iter.is_waiting(peer), - QueryPeerIter::Fixed(iter) => iter.is_waiting(peer) + QueryPeerIter::Fixed(iter) => iter.is_waiting(peer), } } @@ -331,7 +338,7 @@ impl Query { let state = match &mut self.peer_iter { QueryPeerIter::Closest(iter) => iter.next(now), QueryPeerIter::ClosestDisjoint(iter) => iter.next(now), - QueryPeerIter::Fixed(iter) => iter.next() + QueryPeerIter::Fixed(iter) => iter.next(), }; if let PeersIterState::Waiting(Some(_)) = state { @@ -360,12 +367,18 @@ impl Query { /// [`QueryPoolState::Finished`]. pub fn try_finish<'a, I>(&mut self, peers: I) -> bool where - I: IntoIterator + I: IntoIterator, { match &mut self.peer_iter { - QueryPeerIter::Closest(iter) => { iter.finish(); true }, + QueryPeerIter::Closest(iter) => { + iter.finish(); + true + } QueryPeerIter::ClosestDisjoint(iter) => iter.finish_paths(peers), - QueryPeerIter::Fixed(iter) => { iter.finish(); true } + QueryPeerIter::Fixed(iter) => { + iter.finish(); + true + } } } @@ -377,7 +390,7 @@ impl Query { match &mut self.peer_iter { QueryPeerIter::Closest(iter) => iter.finish(), QueryPeerIter::ClosestDisjoint(iter) => iter.finish(), - QueryPeerIter::Fixed(iter) => iter.finish() + QueryPeerIter::Fixed(iter) => iter.finish(), } } @@ -389,7 +402,7 @@ impl Query { match &self.peer_iter { QueryPeerIter::Closest(iter) => iter.is_finished(), QueryPeerIter::ClosestDisjoint(iter) => iter.is_finished(), - QueryPeerIter::Fixed(iter) => iter.is_finished() + QueryPeerIter::Fixed(iter) => iter.is_finished(), } } @@ -398,9 +411,13 @@ impl Query { let peers = match self.peer_iter { QueryPeerIter::Closest(iter) => Either::Left(Either::Left(iter.into_result())), QueryPeerIter::ClosestDisjoint(iter) => Either::Left(Either::Right(iter.into_result())), - QueryPeerIter::Fixed(iter) => Either::Right(iter.into_result()) + QueryPeerIter::Fixed(iter) => Either::Right(iter.into_result()), }; - QueryResult { peers, inner: self.inner, stats: self.stats } + QueryResult { + peers, + inner: self.inner, + stats: self.stats, + } } } @@ -411,7 +428,7 @@ pub struct QueryResult { /// The successfully contacted peers. pub peers: TPeers, /// The collected query statistics. - pub stats: QueryStats + pub stats: QueryStats, } /// Execution statistics of a query. @@ -421,7 +438,7 @@ pub struct QueryStats { success: u32, failure: u32, start: Option, - end: Option + end: Option, } impl QueryStats { @@ -490,9 +507,9 @@ impl QueryStats { failure: self.failure + other.failure, start: match (self.start, other.start) { (Some(a), Some(b)) => Some(std::cmp::min(a, b)), - (a, b) => a.or(b) + (a, b) => a.or(b), }, - end: std::cmp::max(self.end, other.end) + end: std::cmp::max(self.end, other.end), } } } diff --git a/protocols/kad/src/query/peers.rs b/protocols/kad/src/query/peers.rs index 964068aa25a..7a177a494cf 100644 --- a/protocols/kad/src/query/peers.rs +++ b/protocols/kad/src/query/peers.rs @@ -63,5 +63,5 @@ pub enum PeersIterState<'a> { WaitingAtCapacity, /// The iterator finished. - Finished + Finished, } diff --git a/protocols/kad/src/query/peers/closest.rs b/protocols/kad/src/query/peers/closest.rs index 702335c50f8..684c109b934 100644 --- a/protocols/kad/src/query/peers/closest.rs +++ b/protocols/kad/src/query/peers/closest.rs @@ -20,11 +20,11 @@ use super::*; -use crate::{K_VALUE, ALPHA_VALUE}; -use crate::kbucket::{Key, KeyBytes, Distance}; +use crate::kbucket::{Distance, Key, KeyBytes}; +use crate::{ALPHA_VALUE, K_VALUE}; use libp2p_core::PeerId; -use std::{time::Duration, iter::FromIterator, num::NonZeroUsize}; use std::collections::btree_map::{BTreeMap, Entry}; +use std::{iter::FromIterator, num::NonZeroUsize, time::Duration}; use wasm_timer::Instant; pub mod disjoint; @@ -88,16 +88,24 @@ impl ClosestPeersIter { /// Creates a new iterator with a default configuration. pub fn new(target: KeyBytes, known_closest_peers: I) -> Self where - I: IntoIterator> + I: IntoIterator>, { - Self::with_config(ClosestPeersIterConfig::default(), target, known_closest_peers) + Self::with_config( + ClosestPeersIterConfig::default(), + target, + known_closest_peers, + ) } /// Creates a new iterator with the given configuration. - pub fn with_config(config: ClosestPeersIterConfig, target: T, known_closest_peers: I) -> Self + pub fn with_config( + config: ClosestPeersIterConfig, + target: T, + known_closest_peers: I, + ) -> Self where I: IntoIterator>, - T: Into + T: Into, { let target = target.into(); @@ -110,17 +118,18 @@ impl ClosestPeersIter { let state = PeerState::NotContacted; (distance, Peer { key, state }) }) - .take(K_VALUE.into())); + .take(K_VALUE.into()), + ); // The iterator initially makes progress by iterating towards the target. - let state = State::Iterating { no_progress : 0 }; + let state = State::Iterating { no_progress: 0 }; ClosestPeersIter { config, target, state, closest_peers, - num_waiting: 0 + num_waiting: 0, } } @@ -142,10 +151,10 @@ impl ClosestPeersIter { /// calling this function has no effect and `false` is returned. pub fn on_success(&mut self, peer: &PeerId, closer_peers: I) -> bool where - I: IntoIterator + I: IntoIterator, { if let State::Finished = self.state { - return false + return false; } let key = Key::from(*peer); @@ -163,10 +172,8 @@ impl ClosestPeersIter { PeerState::Unresponsive => { e.get_mut().state = PeerState::Succeeded; } - PeerState::NotContacted - | PeerState::Failed - | PeerState::Succeeded => return false - } + PeerState::NotContacted | PeerState::Failed | PeerState::Succeeded => return false, + }, } let num_closest = self.closest_peers.len(); @@ -176,7 +183,10 @@ impl ClosestPeersIter { for peer in closer_peers { let key = peer.into(); let distance = self.target.distance(&key); - let peer = Peer { key, state: PeerState::NotContacted }; + let peer = Peer { + key, + state: PeerState::NotContacted, + }; self.closest_peers.entry(distance).or_insert(peer); // The iterator makes progress if the new peer is either closer to the target // than any peer seen so far (i.e. is the first entry), or the iterator did @@ -195,13 +205,14 @@ impl ClosestPeersIter { State::Iterating { no_progress } } } - State::Stalled => + State::Stalled => { if progress { State::Iterating { no_progress: 0 } } else { State::Stalled } - State::Finished => State::Finished + } + State::Finished => State::Finished, }; true @@ -219,7 +230,7 @@ impl ClosestPeersIter { /// calling this function has no effect and `false` is returned. pub fn on_failure(&mut self, peer: &PeerId) -> bool { if let State::Finished = self.state { - return false + return false; } let key = Key::from(*peer); @@ -233,13 +244,9 @@ impl ClosestPeersIter { self.num_waiting -= 1; e.get_mut().state = PeerState::Failed } - PeerState::Unresponsive => { - e.get_mut().state = PeerState::Failed - } - PeerState::NotContacted - | PeerState::Failed - | PeerState::Succeeded => return false - } + PeerState::Unresponsive => e.get_mut().state = PeerState::Failed, + PeerState::NotContacted | PeerState::Failed | PeerState::Succeeded => return false, + }, } true @@ -248,10 +255,11 @@ impl ClosestPeersIter { /// Returns the list of peers for which the iterator is currently waiting /// for results. pub fn waiting(&self) -> impl Iterator { - self.closest_peers.values().filter_map(|peer| - match peer.state { + self.closest_peers + .values() + .filter_map(|peer| match peer.state { PeerState::Waiting(..) => Some(peer.key.preimage()), - _ => None + _ => None, }) } @@ -269,7 +277,7 @@ impl ClosestPeersIter { /// Advances the state of the iterator, potentially getting a new peer to contact. pub fn next(&mut self, now: Instant) -> PeersIterState<'_> { if let State::Finished = self.state { - return PeersIterState::Finished + return PeersIterState::Finished; } // Count the number of peers that returned a result. If there is a @@ -292,13 +300,11 @@ impl ClosestPeersIter { debug_assert!(self.num_waiting > 0); self.num_waiting -= 1; peer.state = PeerState::Unresponsive - } - else if at_capacity { + } else if at_capacity { // The iterator is still waiting for a result from a peer and is // at capacity w.r.t. the maximum number of peers being waited on. - return PeersIterState::WaitingAtCapacity - } - else { + return PeersIterState::WaitingAtCapacity; + } else { // The iterator is still waiting for a result from a peer and the // `result_counter` did not yet reach `num_results`. Therefore // the iterator is not yet done, regardless of already successful @@ -307,26 +313,28 @@ impl ClosestPeersIter { } } - PeerState::Succeeded => + PeerState::Succeeded => { if let Some(ref mut cnt) = result_counter { *cnt += 1; // If `num_results` successful results have been delivered for the // closest peers, the iterator is done. if *cnt >= self.config.num_results.get() { self.state = State::Finished; - return PeersIterState::Finished + return PeersIterState::Finished; } } + } - PeerState::NotContacted => + PeerState::NotContacted => { if !at_capacity { let timeout = now + self.config.peer_timeout; peer.state = PeerState::Waiting(timeout); self.num_waiting += 1; - return PeersIterState::Waiting(Some(Cow::Borrowed(peer.key.preimage()))) + return PeersIterState::Waiting(Some(Cow::Borrowed(peer.key.preimage()))); } else { - return PeersIterState::WaitingAtCapacity + return PeersIterState::WaitingAtCapacity; } + } PeerState::Unresponsive | PeerState::Failed => { // Skip over unresponsive or failed peers. @@ -379,11 +387,12 @@ impl ClosestPeersIter { /// k closest nodes it has not already queried". fn at_capacity(&self) -> bool { match self.state { - State::Stalled => self.num_waiting >= usize::max( - self.config.num_results.get(), self.config.parallelism.get() - ), + State::Stalled => { + self.num_waiting + >= usize::max(self.config.num_results.get(), self.config.parallelism.get()) + } State::Iterating { .. } => self.num_waiting >= self.config.parallelism.get(), - State::Finished => true + State::Finished => true, } } } @@ -425,14 +434,14 @@ enum State { /// from the closest peers (not counting those that failed or are unresponsive) /// or because the iterator ran out of peers that have not yet delivered /// results (or failed). - Finished + Finished, } /// Representation of a peer in the context of a iterator. #[derive(Debug, Clone)] struct Peer { key: Key, - state: PeerState + state: PeerState, } /// The state of a single `Peer`. @@ -466,19 +475,29 @@ enum PeerState { #[cfg(test)] mod tests { use super::*; - use libp2p_core::{PeerId, multihash::{Code, Multihash}}; + use libp2p_core::{ + multihash::{Code, Multihash}, + PeerId, + }; use quickcheck::*; - use rand::{Rng, rngs::StdRng, SeedableRng}; + use rand::{rngs::StdRng, Rng, SeedableRng}; use std::{iter, time::Duration}; fn random_peers(n: usize, g: &mut R) -> Vec { - (0 .. n).map(|_| PeerId::from_multihash( - Multihash::wrap(Code::Sha2_256.into(), &g.gen::<[u8; 32]>()).unwrap() - ).unwrap()).collect() + (0..n) + .map(|_| { + PeerId::from_multihash( + Multihash::wrap(Code::Sha2_256.into(), &g.gen::<[u8; 32]>()).unwrap(), + ) + .unwrap() + }) + .collect() } fn sorted>(target: &T, peers: &Vec>) -> bool { - peers.windows(2).all(|w| w[0].distance(&target) < w[1].distance(&target)) + peers + .windows(2) + .all(|w| w[0].distance(&target) < w[1].distance(&target)) } impl Arbitrary for ClosestPeersIter { @@ -510,26 +529,32 @@ mod tests { fn prop(iter: ClosestPeersIter) { let target = iter.target.clone(); - let (keys, states): (Vec<_>, Vec<_>) = iter.closest_peers + let (keys, states): (Vec<_>, Vec<_>) = iter + .closest_peers .values() .map(|e| (e.key.clone(), &e.state)) .unzip(); - let none_contacted = states - .iter() - .all(|s| match s { - PeerState::NotContacted => true, - _ => false - }); - - assert!(none_contacted, - "Unexpected peer state in new iterator."); - assert!(sorted(&target, &keys), - "Closest peers in new iterator not sorted by distance to target."); - assert_eq!(iter.num_waiting(), 0, - "Unexpected peers in progress in new iterator."); - assert_eq!(iter.into_result().count(), 0, - "Unexpected closest peers in new iterator"); + let none_contacted = states.iter().all(|s| match s { + PeerState::NotContacted => true, + _ => false, + }); + + assert!(none_contacted, "Unexpected peer state in new iterator."); + assert!( + sorted(&target, &keys), + "Closest peers in new iterator not sorted by distance to target." + ); + assert_eq!( + iter.num_waiting(), + 0, + "Unexpected peers in progress in new iterator." + ); + assert_eq!( + iter.into_result().count(), + 0, + "Unexpected closest peers in new iterator" + ); } QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _) @@ -541,7 +566,8 @@ mod tests { let now = Instant::now(); let mut rng = StdRng::from_seed(seed.0); - let mut expected = iter.closest_peers + let mut expected = iter + .closest_peers .values() .map(|e| e.key.clone()) .collect::>(); @@ -559,8 +585,7 @@ mod tests { // Split off the next up to `parallelism` expected peers. else if expected.len() < max_parallelism { remaining = Vec::new(); - } - else { + } else { remaining = expected.split_off(max_parallelism); } @@ -570,7 +595,9 @@ mod tests { PeersIterState::Finished => break 'finished, PeersIterState::Waiting(Some(p)) => assert_eq!(&*p, k.preimage()), PeersIterState::Waiting(None) => panic!("Expected another peer."), - PeersIterState::WaitingAtCapacity => panic!("Unexpectedly reached capacity.") + PeersIterState::WaitingAtCapacity => { + panic!("Unexpectedly reached capacity.") + } } } let num_waiting = iter.num_waiting(); @@ -611,7 +638,7 @@ mod tests { // of results. let all_contacted = iter.closest_peers.values().all(|e| match e.state { PeerState::NotContacted | PeerState::Waiting { .. } => false, - _ => true + _ => true, }); let target = iter.target.clone(); @@ -634,7 +661,9 @@ mod tests { } } - QuickCheck::new().tests(10).quickcheck(prop as fn(_, _) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _) -> _) } #[test] @@ -648,7 +677,7 @@ mod tests { // A first peer reports a "closer" peer. let peer1 = match iter.next(now) { PeersIterState::Waiting(Some(p)) => p.into_owned(), - _ => panic!("No peer.") + _ => panic!("No peer."), }; iter.on_success(&peer1, closer.clone()); // Duplicate result from te same peer. @@ -665,25 +694,38 @@ mod tests { }; // The "closer" peer must only be in the iterator once. - let n = iter.closest_peers.values().filter(|e| e.key.preimage() == &closer[0]).count(); + let n = iter + .closest_peers + .values() + .filter(|e| e.key.preimage() == &closer[0]) + .count(); assert_eq!(n, 1); true } - QuickCheck::new().tests(10).quickcheck(prop as fn(_, _) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _) -> _) } #[test] fn timeout() { fn prop(mut iter: ClosestPeersIter) -> bool { let mut now = Instant::now(); - let peer = iter.closest_peers.values().next().unwrap().key.clone().into_preimage(); + let peer = iter + .closest_peers + .values() + .next() + .unwrap() + .key + .clone() + .into_preimage(); // Poll the iterator for the first peer to be in progress. match iter.next(now) { PeersIterState::Waiting(Some(id)) => assert_eq!(&*id, &peer), - _ => panic!() + _ => panic!(), } // Artificially advance the clock. @@ -692,10 +734,13 @@ mod tests { // Advancing the iterator again should mark the first peer as unresponsive. let _ = iter.next(now); match &iter.closest_peers.values().next().unwrap() { - Peer { key, state: PeerState::Unresponsive } => { + Peer { + key, + state: PeerState::Unresponsive, + } => { assert_eq!(key.preimage(), &peer); - }, - Peer { state, .. } => panic!("Unexpected peer state: {:?}", state) + } + Peer { state, .. } => panic!("Unexpected peer state: {:?}", state), } let finished = iter.is_finished(); @@ -727,7 +772,7 @@ mod tests { PeersIterState::Waiting(Some(p)) => { let peer = p.clone().into_owned(); iter.on_failure(&peer); - }, + } _ => panic!("Expected iterator to yield another peer to query."), } } @@ -751,10 +796,8 @@ mod tests { ) } - iter.num_waiting = usize::max( - iter.config.parallelism.get(), - iter.config.num_results.get(), - ); + iter.num_waiting = + usize::max(iter.config.parallelism.get(), iter.config.num_results.get()); assert!( iter.at_capacity(), "Iterator should be at capacity if `max(parallelism, num_results)` requests are \ diff --git a/protocols/kad/src/query/peers/closest/disjoint.rs b/protocols/kad/src/query/peers/closest/disjoint.rs index b295355634b..01506ff6f7b 100644 --- a/protocols/kad/src/query/peers/closest/disjoint.rs +++ b/protocols/kad/src/query/peers/closest/disjoint.rs @@ -72,7 +72,10 @@ impl ClosestDisjointPeersIter { I: IntoIterator>, T: Into + Clone, { - let peers = known_closest_peers.into_iter().take(K_VALUE.get()).collect::>(); + let peers = known_closest_peers + .into_iter() + .take(K_VALUE.get()) + .collect::>(); let iters = (0..config.parallelism.get()) // NOTE: All [`ClosestPeersIter`] share the same set of peers at // initialization. The [`ClosestDisjointPeersIter.contacted_peers`] @@ -88,7 +91,9 @@ impl ClosestDisjointPeersIter { config, target: target.into(), iters, - iter_order: (0..iters_len).map(IteratorIndex as fn(usize) -> IteratorIndex).cycle(), + iter_order: (0..iters_len) + .map(IteratorIndex as fn(usize) -> IteratorIndex) + .cycle(), contacted_peers: HashMap::new(), } } @@ -106,7 +111,11 @@ impl ClosestDisjointPeersIter { pub fn on_failure(&mut self, peer: &PeerId) -> bool { let mut updated = false; - if let Some(PeerState{ initiated_by, response }) = self.contacted_peers.get_mut(peer) { + if let Some(PeerState { + initiated_by, + response, + }) = self.contacted_peers.get_mut(peer) + { updated = self.iters[*initiated_by].on_failure(peer); if updated { @@ -148,7 +157,11 @@ impl ClosestDisjointPeersIter { { let mut updated = false; - if let Some(PeerState{ initiated_by, response }) = self.contacted_peers.get_mut(peer) { + if let Some(PeerState { + initiated_by, + response, + }) = self.contacted_peers.get_mut(peer) + { // Pass the new `closer_peers` to the iterator that first yielded // the peer. updated = self.iters[*initiated_by].on_success(peer, closer_peers); @@ -185,7 +198,7 @@ impl ClosestDisjointPeersIter { let mut state = None; // Ensure querying each iterator at most once. - for _ in 0 .. self.iters.len() { + for _ in 0..self.iters.len() { let i = self.iter_order.next().expect("Cycle never ends."); let iter = &mut self.iters[i]; @@ -198,7 +211,7 @@ impl ClosestDisjointPeersIter { // [`ClosestPeersIter`] yielded a peer. Thus this state is // unreachable. unreachable!(); - }, + } Some(PeersIterState::Waiting(None)) => {} Some(PeersIterState::WaitingAtCapacity) => { // At least one ClosestPeersIter is no longer at capacity, thus the @@ -210,14 +223,13 @@ impl ClosestDisjointPeersIter { unreachable!(); } None => state = Some(PeersIterState::Waiting(None)), - }; break; } PeersIterState::Waiting(Some(peer)) => { match self.contacted_peers.get_mut(&*peer) { - Some(PeerState{ response, .. }) => { + Some(PeerState { response, .. }) => { // Another iterator already contacted this peer. let peer = peer.into_owned(); @@ -225,27 +237,27 @@ impl ClosestDisjointPeersIter { // The iterator will be notified later whether the given node // was successfully contacted or not. See // [`ClosestDisjointPeersIter::on_success`] for details. - ResponseState::Waiting => {}, + ResponseState::Waiting => {} ResponseState::Succeeded => { // Given that iterator was not the first to contact the peer // it will not be made aware of the closer peers discovered // to uphold the S/Kademlia disjoint paths guarantee. See // [`ClosestDisjointPeersIter::on_success`] for details. iter.on_success(&peer, std::iter::empty()); - }, + } ResponseState::Failed => { iter.on_failure(&peer); - }, + } } - }, + } None => { // The iterator is the first to contact this peer. - self.contacted_peers.insert( - peer.clone().into_owned(), - PeerState::new(i), - ); - return PeersIterState::Waiting(Some(Cow::Owned(peer.into_owned()))); - }, + self.contacted_peers + .insert(peer.clone().into_owned(), PeerState::new(i)); + return PeersIterState::Waiting(Some(Cow::Owned( + peer.into_owned(), + ))); + } } } PeersIterState::WaitingAtCapacity => { @@ -255,13 +267,13 @@ impl ClosestDisjointPeersIter { // [`ClosestPeersIter`] yielded a peer. Thus this state is // unreachable. unreachable!(); - }, + } Some(PeersIterState::Waiting(None)) => {} Some(PeersIterState::WaitingAtCapacity) => {} Some(PeersIterState::Finished) => { // `state` is never set to `Finished`. unreachable!(); - }, + } None => state = Some(PeersIterState::WaitingAtCapacity), }; @@ -280,10 +292,10 @@ impl ClosestDisjointPeersIter { /// See [`crate::query::Query::try_finish`] for details. pub fn finish_paths<'a, I>(&mut self, peers: I) -> bool where - I: IntoIterator + I: IntoIterator, { for peer in peers { - if let Some(PeerState{ initiated_by, .. }) = self.contacted_peers.get_mut(peer) { + if let Some(PeerState { initiated_by, .. }) = self.contacted_peers.get_mut(peer) { self.iters[*initiated_by].finish(); } } @@ -312,7 +324,9 @@ impl ClosestDisjointPeersIter { /// differentiate benign from faulty paths it as well returns faulty /// peers and thus overall returns more than `num_results` peers. pub fn into_result(self) -> impl Iterator { - let result_per_path= self.iters.into_iter() + let result_per_path = self + .iters + .into_iter() .map(|iter| iter.into_result().map(Key::from)); ResultIter::new(self.target, result_per_path).map(Key::into_preimage) @@ -370,7 +384,8 @@ enum ResponseState { // // Note: This operates under the assumption that `I` is ordered. #[derive(Clone, Debug)] -struct ResultIter where +struct ResultIter +where I: Iterator>, { target: KeyBytes, @@ -379,7 +394,7 @@ struct ResultIter where impl>> ResultIter { fn new(target: KeyBytes, iters: impl Iterator) -> Self { - ResultIter{ + ResultIter { target, iters: iters.map(Iterator::peekable).collect(), } @@ -392,36 +407,34 @@ impl>> Iterator for ResultIter { fn next(&mut self) -> Option { let target = &self.target; - self.iters.iter_mut() + self.iters + .iter_mut() // Find the iterator with the next closest peer. - .fold( - Option::<&mut Peekable<_>>::None, - |iter_a, iter_b| { - let iter_a = match iter_a { - Some(iter_a) => iter_a, - None => return Some(iter_b), - }; - - match (iter_a.peek(), iter_b.peek()) { - (Some(next_a), Some(next_b)) => { - if next_a == next_b { - // Remove from one for deduplication. - iter_b.next(); - return Some(iter_a) - } + .fold(Option::<&mut Peekable<_>>::None, |iter_a, iter_b| { + let iter_a = match iter_a { + Some(iter_a) => iter_a, + None => return Some(iter_b), + }; + + match (iter_a.peek(), iter_b.peek()) { + (Some(next_a), Some(next_b)) => { + if next_a == next_b { + // Remove from one for deduplication. + iter_b.next(); + return Some(iter_a); + } - if target.distance(next_a) < target.distance(next_b) { - Some(iter_a) - } else { - Some(iter_b) - } - }, - (Some(_), None) => Some(iter_a), - (None, Some(_)) => Some(iter_b), - (None, None) => None, + if target.distance(next_a) < target.distance(next_b) { + Some(iter_a) + } else { + Some(iter_b) + } } - }, - ) + (Some(_), None) => Some(iter_a), + (None, Some(_)) => Some(iter_b), + (None, None) => None, + } + }) // Pop off the next closest peer from that iterator. .and_then(Iterator::next) } @@ -434,7 +447,7 @@ mod tests { use crate::K_VALUE; use libp2p_core::multihash::{Code, Multihash}; use quickcheck::*; - use rand::{Rng, seq::SliceRandom}; + use rand::{seq::SliceRandom, Rng}; use std::collections::HashSet; use std::iter; @@ -442,22 +455,18 @@ mod tests { fn arbitrary(g: &mut G) -> Self { let target = Target::arbitrary(g).0; let num_closest_iters = g.gen_range(0, 20 + 1); - let peers = random_peers( - g.gen_range(0, 20 * num_closest_iters + 1), - g, - ); + let peers = random_peers(g.gen_range(0, 20 * num_closest_iters + 1), g); let iters: Vec<_> = (0..num_closest_iters) .map(|_| { let num_peers = g.gen_range(0, 20 + 1); - let mut peers = peers.choose_multiple(g, num_peers) + let mut peers = peers + .choose_multiple(g, num_peers) .cloned() .map(Key::from) .collect::>(); - peers.sort_unstable_by(|a, b| { - target.distance(a).cmp(&target.distance(b)) - }); + peers.sort_unstable_by(|a, b| target.distance(a).cmp(&target.distance(b))); peers.into_iter() }) @@ -467,7 +476,8 @@ mod tests { } fn shrink(&self) -> Box> { - let peers = self.iters + let peers = self + .iters .clone() .into_iter() .flatten() @@ -475,7 +485,9 @@ mod tests { .into_iter() .collect::>(); - let iters = self.iters.clone() + let iters = self + .iters + .clone() .into_iter() .map(|iter| iter.collect::>()) .collect(); @@ -503,14 +515,18 @@ mod tests { // The peer that should not be included. let peer = self.peers.pop()?; - let iters = self.iters.clone().into_iter() + let iters = self + .iters + .clone() + .into_iter() .filter_map(|mut iter| { iter.retain(|p| p != &peer); if iter.is_empty() { return None; } Some(iter.into_iter()) - }).collect::>(); + }) + .collect::>(); Some(ResultIter::new(self.target.clone(), iters.into_iter())) } @@ -526,16 +542,22 @@ mod tests { } fn random_peers(n: usize, g: &mut R) -> Vec { - (0 .. n).map(|_| PeerId::from_multihash( - Multihash::wrap(Code::Sha2_256.into(), &g.gen::<[u8; 32]>()).unwrap() - ).unwrap()).collect() + (0..n) + .map(|_| { + PeerId::from_multihash( + Multihash::wrap(Code::Sha2_256.into(), &g.gen::<[u8; 32]>()).unwrap(), + ) + .unwrap() + }) + .collect() } #[test] fn result_iter_returns_deduplicated_ordered_peer_id_stream() { fn prop(result_iter: ResultIter>>) { let expected = { - let mut deduplicated = result_iter.clone() + let mut deduplicated = result_iter + .clone() .iters .into_iter() .flatten() @@ -545,7 +567,10 @@ mod tests { .collect::>(); deduplicated.sort_unstable_by(|a, b| { - result_iter.target.distance(a).cmp(&result_iter.target.distance(b)) + result_iter + .target + .distance(a) + .cmp(&result_iter.target.distance(b)) }); deduplicated @@ -560,7 +585,7 @@ mod tests { #[derive(Debug, Clone)] struct Parallelism(NonZeroUsize); - impl Arbitrary for Parallelism{ + impl Arbitrary for Parallelism { fn arbitrary(g: &mut G) -> Self { Parallelism(NonZeroUsize::new(g.gen_range(1, 10)).unwrap()) } @@ -569,7 +594,7 @@ mod tests { #[derive(Debug, Clone)] struct NumResults(NonZeroUsize); - impl Arbitrary for NumResults{ + impl Arbitrary for NumResults { fn arbitrary(g: &mut G) -> Self { NumResults(NonZeroUsize::new(g.gen_range(1, K_VALUE.get())).unwrap()) } @@ -604,13 +629,12 @@ mod tests { let now = Instant::now(); let target: KeyBytes = Key::from(PeerId::random()).into(); - let mut pool = [0; 12].iter() + let mut pool = [0; 12] + .iter() .map(|_| Key::from(PeerId::random())) .collect::>(); - pool.sort_unstable_by(|a, b| { - target.distance(a).cmp(&target.distance(b)) - }); + pool.sort_unstable_by(|a, b| target.distance(a).cmp(&target.distance(b))); let known_closest_peers = pool.split_off(pool.len() - 3); @@ -637,10 +661,7 @@ mod tests { } } - assert_eq!( - PeersIterState::WaitingAtCapacity, - peers_iter.next(now), - ); + assert_eq!(PeersIterState::WaitingAtCapacity, peers_iter.next(now),); let response_2 = pool.split_off(pool.len() - 3); let response_3 = pool.split_off(pool.len() - 3); @@ -651,7 +672,10 @@ mod tests { // Response from malicious peer 1. peers_iter.on_success( known_closest_peers[0].preimage(), - malicious_response_1.clone().into_iter().map(|k| k.preimage().clone()), + malicious_response_1 + .clone() + .into_iter() + .map(|k| k.preimage().clone()), ); // Response from peer 2. @@ -676,7 +700,7 @@ mod tests { } else { panic!("Expected iterator to return peer to query."); } - }; + } // Expect a peer from each disjoint path. assert!(next_to_query.contains(malicious_response_1[0].preimage())); @@ -696,10 +720,7 @@ mod tests { } } - assert_eq!( - PeersIterState::Finished, - peers_iter.next(now), - ); + assert_eq!(PeersIterState::Finished, peers_iter.next(now),); let final_peers: Vec<_> = peers_iter.into_result().collect(); @@ -715,7 +736,9 @@ mod tests { impl std::fmt::Debug for Graph { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - fmt.debug_list().entries(self.0.iter().map(|(id, _)| id)).finish() + fmt.debug_list() + .entries(self.0.iter().map(|(id, _)| id)) + .finish() } } @@ -727,22 +750,24 @@ mod tests { .collect::>(); // Make each peer aware of its direct neighborhood. - let mut peers = peer_ids.clone().into_iter() + let mut peers = peer_ids + .clone() + .into_iter() .map(|(peer_id, key)| { - peer_ids.sort_unstable_by(|(_, a), (_, b)| { - key.distance(a).cmp(&key.distance(b)) - }); + peer_ids + .sort_unstable_by(|(_, a), (_, b)| key.distance(a).cmp(&key.distance(b))); assert_eq!(peer_id, peer_ids[0].0); - let known_peers = peer_ids.iter() + let known_peers = peer_ids + .iter() // Skip itself. .skip(1) .take(K_VALUE.get()) .cloned() .collect::>(); - (peer_id, Peer{ known_peers }) + (peer_id, Peer { known_peers }) }) .collect::>(); @@ -751,7 +776,8 @@ mod tests { peer_ids.shuffle(g); let num_peers = g.gen_range(K_VALUE.get(), peer_ids.len() + 1); - let mut random_peer_ids = peer_ids.choose_multiple(g, num_peers) + let mut random_peer_ids = peer_ids + .choose_multiple(g, num_peers) // Make sure not to include itself. .filter(|(id, _)| peer_id != id) .cloned() @@ -760,7 +786,10 @@ mod tests { peer.known_peers.append(&mut random_peer_ids); peer.known_peers = std::mem::replace(&mut peer.known_peers, vec![]) // Deduplicate peer ids. - .into_iter().collect::>().into_iter().collect(); + .into_iter() + .collect::>() + .into_iter() + .collect(); } Graph(peers) @@ -769,21 +798,22 @@ mod tests { impl Graph { fn get_closest_peer(&self, target: &KeyBytes) -> PeerId { - self.0.iter() + self.0 + .iter() .map(|(peer_id, _)| (target.distance(&Key::from(*peer_id)), peer_id)) - .fold(None, |acc, (distance_b, peer_id_b)| { - match acc { - None => Some((distance_b, peer_id_b)), - Some((distance_a, peer_id_a)) => if distance_a < distance_b { + .fold(None, |acc, (distance_b, peer_id_b)| match acc { + None => Some((distance_b, peer_id_b)), + Some((distance_a, peer_id_a)) => { + if distance_a < distance_b { Some((distance_a, peer_id_a)) } else { Some((distance_b, peer_id_b)) } } - }) .expect("Graph to have at least one peer.") - .1.clone() + .1 + .clone() } } @@ -794,11 +824,15 @@ mod tests { impl Peer { fn get_closest_peers(&mut self, target: &KeyBytes) -> Vec { - self.known_peers.sort_unstable_by(|(_, a), (_, b)| { - target.distance(a).cmp(&target.distance(b)) - }); + self.known_peers + .sort_unstable_by(|(_, a), (_, b)| target.distance(a).cmp(&target.distance(b))); - self.known_peers.iter().take(K_VALUE.get()).map(|(id, _)| id).cloned().collect() + self.known_peers + .iter() + .take(K_VALUE.get()) + .map(|(id, _)| id) + .cloned() + .collect() } } @@ -846,15 +880,16 @@ mod tests { let target: KeyBytes = target.0; let closest_peer = graph.get_closest_peer(&target); - let mut known_closest_peers = graph.0.iter() + let mut known_closest_peers = graph + .0 + .iter() .take(K_VALUE.get()) .map(|(key, _peers)| Key::from(*key)) .collect::>(); - known_closest_peers.sort_unstable_by(|a, b| { - target.distance(a).cmp(&target.distance(b)) - }); + known_closest_peers + .sort_unstable_by(|a, b| target.distance(a).cmp(&target.distance(b))); - let cfg = ClosestPeersIterConfig{ + let cfg = ClosestPeersIterConfig { parallelism: parallelism.0, num_results: num_results.0, ..ClosestPeersIterConfig::default() @@ -923,25 +958,32 @@ mod tests { match iter.next(now) { PeersIterState::Waiting(Some(peer_id)) => { let peer_id = peer_id.clone().into_owned(); - let closest_peers = graph.0.get_mut(&peer_id) + let closest_peers = graph + .0 + .get_mut(&peer_id) .unwrap() .get_closest_peers(&target); iter.on_success(&peer_id, closest_peers); - } , - PeersIterState::WaitingAtCapacity | PeersIterState::Waiting(None) => - panic!("There is never more than one request in flight."), + } + PeersIterState::WaitingAtCapacity | PeersIterState::Waiting(None) => { + panic!("There is never more than one request in flight.") + } PeersIterState::Finished => break, } } - let mut result = iter.into_result().into_iter().map(Key::from).collect::>(); - result.sort_unstable_by(|a, b| { - target.distance(a).cmp(&target.distance(b)) - }); + let mut result = iter + .into_result() + .into_iter() + .map(Key::from) + .collect::>(); + result.sort_unstable_by(|a, b| target.distance(a).cmp(&target.distance(b))); result.into_iter().map(|k| k.into_preimage()).collect() } - QuickCheck::new().tests(10).quickcheck(prop as fn(_, _, _, _) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _, _, _) -> _) } #[test] @@ -957,16 +999,22 @@ mod tests { // Expect peer to be marked as succeeded. assert!(iter.on_success(&peer, iter::empty())); - assert_eq!(iter.contacted_peers.get(&peer), Some(&PeerState { - initiated_by: IteratorIndex(0), - response: ResponseState::Succeeded, - })); + assert_eq!( + iter.contacted_peers.get(&peer), + Some(&PeerState { + initiated_by: IteratorIndex(0), + response: ResponseState::Succeeded, + }) + ); // Expect peer to stay marked as succeeded. assert!(!iter.on_failure(&peer)); - assert_eq!(iter.contacted_peers.get(&peer), Some(&PeerState { - initiated_by: IteratorIndex(0), - response: ResponseState::Succeeded, - })); + assert_eq!( + iter.contacted_peers.get(&peer), + Some(&PeerState { + initiated_by: IteratorIndex(0), + response: ResponseState::Succeeded, + }) + ); } } diff --git a/protocols/kad/src/query/peers/fixed.rs b/protocols/kad/src/query/peers/fixed.rs index b816ea9ce0f..e4be4094eb1 100644 --- a/protocols/kad/src/query/peers/fixed.rs +++ b/protocols/kad/src/query/peers/fixed.rs @@ -22,7 +22,7 @@ use super::*; use fnv::FnvHashMap; use libp2p_core::PeerId; -use std::{vec, collections::hash_map::Entry, num::NonZeroUsize}; +use std::{collections::hash_map::Entry, num::NonZeroUsize, vec}; /// A peer iterator for a fixed set of peers. pub struct FixedPeersIter { @@ -42,7 +42,7 @@ pub struct FixedPeersIter { #[derive(Debug, PartialEq, Eq)] enum State { Waiting { num_waiting: usize }, - Finished + Finished, } #[derive(Copy, Clone, PartialEq, Eq)] @@ -60,7 +60,7 @@ enum PeerState { impl FixedPeersIter { pub fn new(peers: I, parallelism: NonZeroUsize) -> Self where - I: IntoIterator + I: IntoIterator, { let peers = peers.into_iter().collect::>(); @@ -87,7 +87,7 @@ impl FixedPeersIter { if let Some(state @ PeerState::Waiting) = self.peers.get_mut(peer) { *state = PeerState::Succeeded; *num_waiting -= 1; - return true + return true; } } false @@ -108,7 +108,7 @@ impl FixedPeersIter { if let Some(state @ PeerState::Waiting) = self.peers.get_mut(peer) { *state = PeerState::Failed; *num_waiting -= 1; - return true + return true; } } false @@ -134,24 +134,26 @@ impl FixedPeersIter { State::Finished => PeersIterState::Finished, State::Waiting { num_waiting } => { if *num_waiting >= self.parallelism.get() { - return PeersIterState::WaitingAtCapacity + return PeersIterState::WaitingAtCapacity; } loop { match self.iter.next() { - None => if *num_waiting == 0 { - self.state = State::Finished; - return PeersIterState::Finished - } else { - return PeersIterState::Waiting(None) + None => { + if *num_waiting == 0 { + self.state = State::Finished; + return PeersIterState::Finished; + } else { + return PeersIterState::Waiting(None); + } } Some(p) => match self.peers.entry(p) { Entry::Occupied(_) => {} // skip duplicates Entry::Vacant(e) => { *num_waiting += 1; e.insert(PeerState::Waiting); - return PeersIterState::Waiting(Some(Cow::Owned(p))) + return PeersIterState::Waiting(Some(Cow::Owned(p))); } - } + }, } } } @@ -159,13 +161,13 @@ impl FixedPeersIter { } pub fn into_result(self) -> impl Iterator { - self.peers.into_iter() - .filter_map(|(p, s)| - if let PeerState::Succeeded = s { - Some(p) - } else { - None - }) + self.peers.into_iter().filter_map(|(p, s)| { + if let PeerState::Succeeded = s { + Some(p) + } else { + None + } + }) } } @@ -184,12 +186,12 @@ mod test { PeersIterState::Waiting(Some(peer)) => { let peer = peer.into_owned(); iter.on_failure(&peer); - }, + } _ => panic!("Expected iterator to yield peer."), } match iter.next() { - PeersIterState::Waiting(Some(_)) => {}, + PeersIterState::Waiting(Some(_)) => {} PeersIterState::WaitingAtCapacity => panic!( "Expected iterator to return another peer given that the \ previous `on_failure` call should have allowed another peer \ diff --git a/protocols/kad/src/record.rs b/protocols/kad/src/record.rs index 5a15fdd1034..8f1c585d1b8 100644 --- a/protocols/kad/src/record.rs +++ b/protocols/kad/src/record.rs @@ -23,7 +23,7 @@ pub mod store; use bytes::Bytes; -use libp2p_core::{PeerId, Multiaddr, multihash::Multihash}; +use libp2p_core::{multihash::Multihash, Multiaddr, PeerId}; use std::borrow::Borrow; use std::hash::{Hash, Hasher}; use wasm_timer::Instant; @@ -85,7 +85,7 @@ impl Record { /// Creates a new record for insertion into the DHT. pub fn new(key: K, value: Vec) -> Self where - K: Into + K: Into, { Record { key: key.into(), @@ -116,7 +116,7 @@ pub struct ProviderRecord { /// The expiration time as measured by a local, monotonic clock. pub expires: Option, /// The known addresses that the provider may be listening on. - pub addresses: Vec + pub addresses: Vec, } impl Hash for ProviderRecord { @@ -138,7 +138,7 @@ impl ProviderRecord { /// Creates a new provider record for insertion into a `RecordStore`. pub fn new(key: K, provider: PeerId, addresses: Vec) -> Self where - K: Into + K: Into, { ProviderRecord { key: key.into(), @@ -157,8 +157,8 @@ impl ProviderRecord { #[cfg(test)] mod tests { use super::*; - use quickcheck::*; use libp2p_core::multihash::Code; + use quickcheck::*; use rand::Rng; use std::time::Duration; @@ -174,7 +174,11 @@ mod tests { Record { key: Key::arbitrary(g), value: Vec::arbitrary(g), - publisher: if g.gen() { Some(PeerId::random()) } else { None }, + publisher: if g.gen() { + Some(PeerId::random()) + } else { + None + }, expires: if g.gen() { Some(Instant::now() + Duration::from_secs(g.gen_range(0, 60))) } else { diff --git a/protocols/kad/src/record/store.rs b/protocols/kad/src/record/store.rs index 82402ed3c18..9347afedd7c 100644 --- a/protocols/kad/src/record/store.rs +++ b/protocols/kad/src/record/store.rs @@ -22,8 +22,8 @@ mod memory; pub use memory::{MemoryStore, MemoryStoreConfig}; -use crate::K_VALUE; use super::*; +use crate::K_VALUE; use std::borrow::Cow; /// The result of an operation on a `RecordStore`. @@ -92,4 +92,3 @@ pub trait RecordStore<'a> { /// Removes a provider record from the store. fn remove_provider(&'a mut self, k: &Key, p: &PeerId); } - diff --git a/protocols/kad/src/record/store/memory.rs b/protocols/kad/src/record/store/memory.rs index d74f32bdfbf..c6a006b6cd5 100644 --- a/protocols/kad/src/record/store/memory.rs +++ b/protocols/kad/src/record/store/memory.rs @@ -90,21 +90,19 @@ impl MemoryStore { /// Retains the records satisfying a predicate. pub fn retain(&mut self, f: F) where - F: FnMut(&Key, &mut Record) -> bool + F: FnMut(&Key, &mut Record) -> bool, { self.records.retain(f); } } impl<'a> RecordStore<'a> for MemoryStore { - type RecordsIter = iter::Map< - hash_map::Values<'a, Key, Record>, - fn(&'a Record) -> Cow<'a, Record> - >; + type RecordsIter = + iter::Map, fn(&'a Record) -> Cow<'a, Record>>; type ProvidedIter = iter::Map< hash_set::Iter<'a, ProviderRecord>, - fn(&'a ProviderRecord) -> Cow<'a, ProviderRecord> + fn(&'a ProviderRecord) -> Cow<'a, ProviderRecord>, >; fn get(&'a self, k: &Key) -> Option> { @@ -113,7 +111,7 @@ impl<'a> RecordStore<'a> for MemoryStore { fn put(&'a mut self, r: Record) -> Result<()> { if r.value.len() >= self.config.max_value_bytes { - return Err(Error::ValueTooLarge) + return Err(Error::ValueTooLarge); } let num_records = self.records.len(); @@ -124,7 +122,7 @@ impl<'a> RecordStore<'a> for MemoryStore { } hash_map::Entry::Vacant(e) => { if num_records >= self.config.max_records { - return Err(Error::MaxRecords) + return Err(Error::MaxRecords); } e.insert(r); } @@ -146,14 +144,15 @@ impl<'a> RecordStore<'a> for MemoryStore { // Obtain the entry let providers = match self.providers.entry(record.key.clone()) { - e@hash_map::Entry::Occupied(_) => e, - e@hash_map::Entry::Vacant(_) => { + e @ hash_map::Entry::Occupied(_) => e, + e @ hash_map::Entry::Vacant(_) => { if self.config.max_provided_keys == num_keys { - return Err(Error::MaxProvidedKeys) + return Err(Error::MaxProvidedKeys); } e } - }.or_insert_with(Default::default); + } + .or_insert_with(Default::default); if let Some(i) = providers.iter().position(|p| p.provider == record.provider) { // In-place update of an existing provider record. @@ -178,8 +177,7 @@ impl<'a> RecordStore<'a> for MemoryStore { self.provided.remove(&p); } } - } - else if providers.len() < self.config.max_providers_per_key { + } else if providers.len() < self.config.max_providers_per_key { // The distance of the new provider to the key is larger than // the distance of any existing provider, but there is still room. if local_key.preimage() == &record.provider { @@ -192,7 +190,9 @@ impl<'a> RecordStore<'a> for MemoryStore { } fn providers(&'a self, key: &Key) -> Vec { - self.providers.get(key).map_or_else(Vec::new, |ps| ps.clone().into_vec()) + self.providers + .get(key) + .map_or_else(Vec::new, |ps| ps.clone().into_vec()) } fn provided(&'a self) -> Self::ProvidedIter { @@ -225,8 +225,7 @@ mod tests { } fn distance(r: &ProviderRecord) -> kbucket::Distance { - kbucket::Key::new(r.key.clone()) - .distance(&kbucket::Key::from(r.provider)) + kbucket::Key::new(r.key.clone()).distance(&kbucket::Key::from(r.provider)) } #[test] @@ -259,9 +258,10 @@ mod tests { let mut store = MemoryStore::new(PeerId::random()); let key = Key::from(random_multihash()); - let mut records = providers.into_iter().map(|p| { - ProviderRecord::new(key.clone(), p.into_preimage(), Vec::new()) - }).collect::>(); + let mut records = providers + .into_iter() + .map(|p| ProviderRecord::new(key.clone(), p.into_preimage(), Vec::new())) + .collect::>(); for r in &records { assert!(store.add_provider(r.clone()).is_ok()); @@ -283,7 +283,10 @@ mod tests { let key = random_multihash(); let rec = ProviderRecord::new(key, id.clone(), Vec::new()); assert!(store.add_provider(rec.clone()).is_ok()); - assert_eq!(vec![Cow::Borrowed(&rec)], store.provided().collect::>()); + assert_eq!( + vec![Cow::Borrowed(&rec)], + store.provided().collect::>() + ); store.remove_provider(&rec.key, &id); assert_eq!(store.provided().count(), 0); } @@ -304,7 +307,7 @@ mod tests { #[test] fn max_provided_keys() { let mut store = MemoryStore::new(PeerId::random()); - for _ in 0 .. store.config.max_provided_keys { + for _ in 0..store.config.max_provided_keys { let key = random_multihash(); let prv = PeerId::random(); let rec = ProviderRecord::new(key, prv, Vec::new()); diff --git a/protocols/mdns/src/behaviour.rs b/protocols/mdns/src/behaviour.rs index 7348227c4bb..2a170e2d839 100644 --- a/protocols/mdns/src/behaviour.rs +++ b/protocols/mdns/src/behaviour.rs @@ -18,16 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::IPV4_MDNS_MULTICAST_ADDRESS; use crate::dns::{build_query, build_query_response, build_service_discovery_response}; use crate::query::MdnsPacket; +use crate::IPV4_MDNS_MULTICAST_ADDRESS; use async_io::{Async, Timer}; use futures::prelude::*; use if_watch::{IfEvent, IfWatcher}; use libp2p_core::connection::ListenerId; -use libp2p_core::{ - address_translation, multiaddr::Protocol, Multiaddr, PeerId, -}; +use libp2p_core::{address_translation, multiaddr::Protocol, Multiaddr, PeerId}; use libp2p_swarm::{ protocols_handler::DummyProtocolsHandler, NetworkBehaviour, NetworkBehaviourAction, PollParameters, ProtocolsHandler, diff --git a/protocols/ping/src/handler.rs b/protocols/ping/src/handler.rs index ebfc5a0b1a5..1c4233e2b22 100644 --- a/protocols/ping/src/handler.rs +++ b/protocols/ping/src/handler.rs @@ -19,28 +19,23 @@ // DEALINGS IN THE SOFTWARE. use crate::protocol; -use futures::prelude::*; use futures::future::BoxFuture; -use libp2p_core::{UpgradeError, upgrade::NegotiationError}; +use futures::prelude::*; +use libp2p_core::{upgrade::NegotiationError, UpgradeError}; use libp2p_swarm::{ - KeepAlive, - NegotiatedSubstream, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerUpgrErr, - ProtocolsHandlerEvent + KeepAlive, NegotiatedSubstream, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use std::collections::VecDeque; use std::{ error::Error, - io, - fmt, + fmt, io, num::NonZeroU32, task::{Context, Poll}, - time::Duration + time::Duration, }; -use std::collections::VecDeque; -use wasm_timer::Delay; use void::Void; +use wasm_timer::Delay; /// The configuration for outbound pings. #[derive(Clone, Debug)] @@ -82,7 +77,7 @@ impl PingConfig { timeout: Duration::from_secs(20), interval: Duration::from_secs(15), max_failures: NonZeroU32::new(1).expect("1 != 0"), - keep_alive: false + keep_alive: false, } } @@ -144,7 +139,9 @@ pub enum PingFailure { /// The peer does not support the ping protocol. Unsupported, /// The ping failed for reasons other than a timeout. - Other { error: Box } + Other { + error: Box, + }, } impl fmt::Display for PingFailure { @@ -190,7 +187,7 @@ pub struct PingHandler { /// next inbound ping to be answered. inbound: Option, /// Tracks the state of our handler. - state: State + state: State, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -200,7 +197,7 @@ enum State { /// Whether or not we've reported the missing support yet. /// /// This is used to avoid repeated events being emitted for a specific connection. - reported: bool + reported: bool, }, /// We are actively pinging the other peer. Active, @@ -252,11 +249,9 @@ impl ProtocolsHandler for PingHandler { ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { debug_assert_eq!(self.state, State::Active); - self.state = State::Inactive { - reported: false - }; + self.state = State::Inactive { reported: false }; return; - }, + } // Note: This timeout only covers protocol negotiation. ProtocolsHandlerUpgrErr::Timeout => PingFailure::Timeout, e => PingFailure::Other { error: Box::new(e) }, @@ -273,22 +268,25 @@ impl ProtocolsHandler for PingHandler { } } - fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { match self.state { State::Inactive { reported: true } => { - return Poll::Pending // nothing to do on this connection - }, + return Poll::Pending; // nothing to do on this connection + } State::Inactive { reported: false } => { self.state = State::Inactive { reported: true }; return Poll::Ready(ProtocolsHandlerEvent::Custom(Err(PingFailure::Unsupported))); - }, + } State::Active => {} } // Respond to inbound pings. if let Some(fut) = self.inbound.as_mut() { match fut.poll_unpin(cx) { - Poll::Pending => {}, + Poll::Pending => {} Poll::Ready(Err(e)) => { log::debug!("Inbound ping error: {:?}", e); self.inbound = None; @@ -296,7 +294,7 @@ impl ProtocolsHandler for PingHandler { Poll::Ready(Ok(stream)) => { // A ping from a remote peer has been answered, wait for the next. self.inbound = Some(protocol::recv_ping(stream).boxed()); - return Poll::Ready(ProtocolsHandlerEvent::Custom(Ok(PingSuccess::Pong))) + return Poll::Ready(ProtocolsHandlerEvent::Custom(Ok(PingSuccess::Pong))); } } } @@ -318,10 +316,10 @@ impl ProtocolsHandler for PingHandler { if self.failures > 1 || self.config.max_failures.get() > 1 { if self.failures >= self.config.max_failures.get() { log::debug!("Too many failures ({}). Closing connection.", self.failures); - return Poll::Ready(ProtocolsHandlerEvent::Close(error)) + return Poll::Ready(ProtocolsHandlerEvent::Close(error)); } - return Poll::Ready(ProtocolsHandlerEvent::Custom(Err(error))) + return Poll::Ready(ProtocolsHandlerEvent::Custom(Err(error))); } } @@ -333,50 +331,48 @@ impl ProtocolsHandler for PingHandler { self.pending_errors.push_front(PingFailure::Timeout); } else { self.outbound = Some(PingState::Ping(ping)); - break + break; } - }, + } Poll::Ready(Ok((stream, rtt))) => { self.failures = 0; self.timer.reset(self.config.interval); self.outbound = Some(PingState::Idle(stream)); - return Poll::Ready( - ProtocolsHandlerEvent::Custom( - Ok(PingSuccess::Ping { rtt }))) + return Poll::Ready(ProtocolsHandlerEvent::Custom(Ok(PingSuccess::Ping { + rtt, + }))); } Poll::Ready(Err(e)) => { - self.pending_errors.push_front(PingFailure::Other { - error: Box::new(e) - }); + self.pending_errors + .push_front(PingFailure::Other { error: Box::new(e) }); } }, Some(PingState::Idle(stream)) => match self.timer.poll_unpin(cx) { Poll::Pending => { self.outbound = Some(PingState::Idle(stream)); - break - }, + break; + } Poll::Ready(Ok(())) => { self.timer.reset(self.config.timeout); self.outbound = Some(PingState::Ping(protocol::send_ping(stream).boxed())); - }, + } Poll::Ready(Err(e)) => { - return Poll::Ready(ProtocolsHandlerEvent::Close( - PingFailure::Other { - error: Box::new(e) - })) + return Poll::Ready(ProtocolsHandlerEvent::Close(PingFailure::Other { + error: Box::new(e), + })) } - } + }, Some(PingState::OpenStream) => { self.outbound = Some(PingState::OpenStream); - break + break; } None => { self.outbound = Some(PingState::OpenStream); let protocol = SubstreamProtocol::new(protocol::Ping, ()) .with_timeout(self.config.timeout); return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol - }) + protocol, + }); } } } diff --git a/protocols/ping/src/lib.rs b/protocols/ping/src/lib.rs index cd9cc227c9d..d4e3828f430 100644 --- a/protocols/ping/src/lib.rs +++ b/protocols/ping/src/lib.rs @@ -40,13 +40,13 @@ //! [`Swarm`]: libp2p_swarm::Swarm //! [`Transport`]: libp2p_core::Transport -pub mod protocol; pub mod handler; +pub mod protocol; -pub use handler::{PingConfig, PingResult, PingSuccess, PingFailure}; use handler::PingHandler; +pub use handler::{PingConfig, PingFailure, PingResult, PingSuccess}; -use libp2p_core::{PeerId, connection::ConnectionId}; +use libp2p_core::{connection::ConnectionId, PeerId}; use libp2p_swarm::{NetworkBehaviour, NetworkBehaviourAction, PollParameters}; use std::{collections::VecDeque, task::Context, task::Poll}; use void::Void; @@ -99,9 +99,11 @@ impl NetworkBehaviour for Ping { self.events.push_front(PingEvent { peer, result }) } - fn poll(&mut self, _: &mut Context<'_>, _: &mut impl PollParameters) - -> Poll> - { + fn poll( + &mut self, + _: &mut Context<'_>, + _: &mut impl PollParameters, + ) -> Poll> { if let Some(e) = self.events.pop_back() { Poll::Ready(NetworkBehaviourAction::GenerateEvent(e)) } else { diff --git a/protocols/ping/src/protocol.rs b/protocols/ping/src/protocol.rs index aa63833f651..a3138568777 100644 --- a/protocols/ping/src/protocol.rs +++ b/protocols/ping/src/protocol.rs @@ -82,7 +82,7 @@ impl OutboundUpgrade for Ping { /// Sends a ping and waits for the pong. pub async fn send_ping(mut stream: S) -> io::Result<(S, Duration)> where - S: AsyncRead + AsyncWrite + Unpin + S: AsyncRead + AsyncWrite + Unpin, { let payload: [u8; PING_SIZE] = thread_rng().sample(distributions::Standard); log::debug!("Preparing ping payload {:?}", payload); @@ -95,14 +95,17 @@ where if recv_payload == payload { Ok((stream, started.elapsed())) } else { - Err(io::Error::new(io::ErrorKind::InvalidData, "Ping payload mismatch")) + Err(io::Error::new( + io::ErrorKind::InvalidData, + "Ping payload mismatch", + )) } } /// Waits for a ping and sends a pong. pub async fn recv_ping(mut stream: S) -> io::Result where - S: AsyncRead + AsyncWrite + Unpin + S: AsyncRead + AsyncWrite + Unpin, { let mut payload = [0u8; PING_SIZE]; log::debug!("Waiting for ping ..."); @@ -118,11 +121,7 @@ mod tests { use super::*; use libp2p_core::{ multiaddr::multiaddr, - transport::{ - Transport, - ListenerEvent, - memory::MemoryTransport - } + transport::{memory::MemoryTransport, ListenerEvent, Transport}, }; use rand::{thread_rng, Rng}; use std::time::Duration; diff --git a/protocols/ping/tests/ping.rs b/protocols/ping/tests/ping.rs index fcd013352ab..3f10ace782b 100644 --- a/protocols/ping/tests/ping.rs +++ b/protocols/ping/tests/ping.rs @@ -20,13 +20,12 @@ //! Integration tests for the `Ping` network behaviour. +use futures::{channel::mpsc, prelude::*}; use libp2p_core::{ - Multiaddr, - PeerId, identity, muxing::StreamMuxerBox, transport::{self, Transport}, - upgrade + upgrade, Multiaddr, PeerId, }; use libp2p_mplex as mplex; use libp2p_noise as noise; @@ -34,7 +33,6 @@ use libp2p_ping::*; use libp2p_swarm::{DummyBehaviour, KeepAlive, Swarm, SwarmEvent}; use libp2p_tcp::TcpConfig; use libp2p_yamux as yamux; -use futures::{prelude::*, channel::mpsc}; use quickcheck::*; use rand::prelude::*; use std::{num::NonZeroU8, time::Duration}; @@ -65,13 +63,18 @@ fn ping_pong() { loop { match swarm1.select_next_some().await { SwarmEvent::NewListenAddr { address, .. } => tx.send(address).await.unwrap(), - SwarmEvent::Behaviour(PingEvent { peer, result: Ok(PingSuccess::Ping { rtt }) }) => { + SwarmEvent::Behaviour(PingEvent { + peer, + result: Ok(PingSuccess::Ping { rtt }), + }) => { count1 -= 1; if count1 == 0 { - return (pid1.clone(), peer, rtt) + return (pid1.clone(), peer, rtt); } - }, - SwarmEvent::Behaviour(PingEvent { result: Err(e), .. }) => panic!("Ping failure: {:?}", e), + } + SwarmEvent::Behaviour(PingEvent { result: Err(e), .. }) => { + panic!("Ping failure: {:?}", e) + } _ => {} } } @@ -85,17 +88,16 @@ fn ping_pong() { match swarm2.select_next_some().await { SwarmEvent::Behaviour(PingEvent { peer, - result: Ok(PingSuccess::Ping { rtt }) + result: Ok(PingSuccess::Ping { rtt }), }) => { count2 -= 1; if count2 == 0 { - return (pid2.clone(), peer, rtt) + return (pid2.clone(), peer, rtt); } - }, - SwarmEvent::Behaviour(PingEvent { - result: Err(e), - .. - }) => panic!("Ping failure: {:?}", e), + } + SwarmEvent::Behaviour(PingEvent { result: Err(e), .. }) => { + panic!("Ping failure: {:?}", e) + } _ => {} } } @@ -107,7 +109,7 @@ fn ping_pong() { assert!(rtt < Duration::from_millis(50)); } - QuickCheck::new().tests(10).quickcheck(prop as fn(_,_)) + QuickCheck::new().tests(10).quickcheck(prop as fn(_, _)) } /// Tests that the connection is closed upon a configurable @@ -139,18 +141,15 @@ fn max_failures() { match swarm1.select_next_some().await { SwarmEvent::NewListenAddr { address, .. } => tx.send(address).await.unwrap(), SwarmEvent::Behaviour(PingEvent { - result: Ok(PingSuccess::Ping { .. }), .. + result: Ok(PingSuccess::Ping { .. }), + .. }) => { count1 = 0; // there may be an occasional success } - SwarmEvent::Behaviour(PingEvent { - result: Err(_), .. - }) => { + SwarmEvent::Behaviour(PingEvent { result: Err(_), .. }) => { count1 += 1; } - SwarmEvent::ConnectionClosed { .. } => { - return count1 - } + SwarmEvent::ConnectionClosed { .. } => return count1, _ => {} } } @@ -164,18 +163,15 @@ fn max_failures() { loop { match swarm2.select_next_some().await { SwarmEvent::Behaviour(PingEvent { - result: Ok(PingSuccess::Ping { .. }), .. + result: Ok(PingSuccess::Ping { .. }), + .. }) => { count2 = 0; // there may be an occasional success } - SwarmEvent::Behaviour(PingEvent { - result: Err(_), .. - }) => { + SwarmEvent::Behaviour(PingEvent { result: Err(_), .. }) => { count2 += 1; } - SwarmEvent::ConnectionClosed { .. } => { - return count2 - } + SwarmEvent::ConnectionClosed { .. } => return count2, _ => {} } } @@ -186,16 +182,24 @@ fn max_failures() { assert_eq!(u8::max(count1, count2), max_failures.get() - 1); } - QuickCheck::new().tests(10).quickcheck(prop as fn(_,_)) + QuickCheck::new().tests(10).quickcheck(prop as fn(_, _)) } #[test] fn unsupported_doesnt_fail() { let (peer1_id, trans) = mk_transport(MuxerChoice::Mplex); - let mut swarm1 = Swarm::new(trans, DummyBehaviour::with_keep_alive(KeepAlive::Yes), peer1_id.clone()); + let mut swarm1 = Swarm::new( + trans, + DummyBehaviour::with_keep_alive(KeepAlive::Yes), + peer1_id.clone(), + ); let (peer2_id, trans) = mk_transport(MuxerChoice::Mplex); - let mut swarm2 = Swarm::new(trans, Ping::new(PingConfig::new().with_keep_alive(true)), peer2_id.clone()); + let mut swarm2 = Swarm::new( + trans, + Ping::new(PingConfig::new().with_keep_alive(true)), + peer2_id.clone(), + ); let (mut tx, mut rx) = mpsc::channel::(1); @@ -217,7 +221,8 @@ fn unsupported_doesnt_fail() { loop { match swarm2.select_next_some().await { SwarmEvent::Behaviour(PingEvent { - result: Err(PingFailure::Unsupported), .. + result: Err(PingFailure::Unsupported), + .. }) => { swarm2.disconnect_peer_id(peer1_id).unwrap(); } @@ -235,25 +240,24 @@ fn unsupported_doesnt_fail() { result.expect("node with ping should not fail connection due to unsupported protocol"); } - -fn mk_transport(muxer: MuxerChoice) -> ( - PeerId, - transport::Boxed<(PeerId, StreamMuxerBox)> -) { +fn mk_transport(muxer: MuxerChoice) -> (PeerId, transport::Boxed<(PeerId, StreamMuxerBox)>) { let id_keys = identity::Keypair::generate_ed25519(); let peer_id = id_keys.public().to_peer_id(); - let noise_keys = noise::Keypair::::new().into_authentic(&id_keys).unwrap(); - (peer_id, TcpConfig::new() - .nodelay(true) - .upgrade(upgrade::Version::V1) - .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) - .multiplex(match muxer { - MuxerChoice::Yamux => - upgrade::EitherUpgrade::A(yamux::YamuxConfig::default()), - MuxerChoice::Mplex => - upgrade::EitherUpgrade::B(mplex::MplexConfig::default()), - }) - .boxed()) + let noise_keys = noise::Keypair::::new() + .into_authentic(&id_keys) + .unwrap(); + ( + peer_id, + TcpConfig::new() + .nodelay(true) + .upgrade(upgrade::Version::V1) + .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) + .multiplex(match muxer { + MuxerChoice::Yamux => upgrade::EitherUpgrade::A(yamux::YamuxConfig::default()), + MuxerChoice::Mplex => upgrade::EitherUpgrade::B(mplex::MplexConfig::default()), + }) + .boxed(), + ) } #[derive(Debug, Copy, Clone)] diff --git a/protocols/relay/build.rs b/protocols/relay/build.rs index cd7bd3deef6..c3a7d4bd823 100644 --- a/protocols/relay/build.rs +++ b/protocols/relay/build.rs @@ -19,5 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/message.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/message.proto"], &["src"]).unwrap(); } diff --git a/protocols/relay/src/behaviour.rs b/protocols/relay/src/behaviour.rs index 78e9e5d8d66..9b17eca2c51 100644 --- a/protocols/relay/src/behaviour.rs +++ b/protocols/relay/src/behaviour.rs @@ -303,7 +303,7 @@ impl NetworkBehaviour for Relay { fn inject_dial_failure(&mut self, peer_id: &PeerId) { if let Entry::Occupied(o) = self.listeners.entry(*peer_id) { - if matches!(o.get(), RelayListener::Connecting{ .. }) { + if matches!(o.get(), RelayListener::Connecting { .. }) { // By removing the entry, the channel to the listener is dropped and thus the // listener is notified that dialing the relay failed. o.remove_entry(); diff --git a/protocols/relay/src/protocol/incoming_dst_req.rs b/protocols/relay/src/protocol/incoming_dst_req.rs index b3b0ded9de6..d68a15121f5 100644 --- a/protocols/relay/src/protocol/incoming_dst_req.rs +++ b/protocols/relay/src/protocol/incoming_dst_req.rs @@ -23,8 +23,8 @@ use crate::protocol::Peer; use asynchronous_codec::{Framed, FramedParts}; use bytes::BytesMut; -use futures::{future::BoxFuture, prelude::*}; use futures::channel::oneshot; +use futures::{future::BoxFuture, prelude::*}; use libp2p_core::{Multiaddr, PeerId}; use libp2p_swarm::NegotiatedSubstream; use prost::Message; @@ -47,8 +47,7 @@ pub struct IncomingDstReq { src: Peer, } -impl IncomingDstReq -{ +impl IncomingDstReq { /// Creates a `IncomingDstReq`. pub(crate) fn new(stream: Framed, src: Peer) -> Self { IncomingDstReq { @@ -73,7 +72,10 @@ impl IncomingDstReq /// stream then points to the source (as retreived with `src_id()` and `src_addrs()`). pub fn accept( self, - ) -> BoxFuture<'static, Result<(PeerId, super::Connection, oneshot::Receiver<()>), IncomingDstReqError>> { + ) -> BoxFuture< + 'static, + Result<(PeerId, super::Connection, oneshot::Receiver<()>), IncomingDstReqError>, + > { let IncomingDstReq { mut stream, src } = self; let msg = CircuitRelay { r#type: Some(circuit_relay::Type::Status.into()), @@ -101,7 +103,11 @@ impl IncomingDstReq let (tx, rx) = oneshot::channel(); - Ok((src.peer_id, super::Connection::new(read_buffer.freeze(), io, tx), rx)) + Ok(( + src.peer_id, + super::Connection::new(read_buffer.freeze(), io, tx), + rx, + )) } .boxed() } diff --git a/protocols/relay/src/protocol/incoming_relay_req.rs b/protocols/relay/src/protocol/incoming_relay_req.rs index 6f585db2854..948a2281f5b 100644 --- a/protocols/relay/src/protocol/incoming_relay_req.rs +++ b/protocols/relay/src/protocol/incoming_relay_req.rs @@ -23,7 +23,7 @@ use crate::message_proto::{circuit_relay, circuit_relay::Status, CircuitRelay}; use crate::protocol::Peer; use asynchronous_codec::{Framed, FramedParts}; -use bytes::{BytesMut, Bytes}; +use bytes::{Bytes, BytesMut}; use futures::channel::oneshot; use futures::future::BoxFuture; use futures::prelude::*; @@ -50,8 +50,7 @@ pub struct IncomingRelayReq { _notifier: oneshot::Sender<()>, } -impl IncomingRelayReq -{ +impl IncomingRelayReq { /// Creates a [`IncomingRelayReq`] as well as a Future that resolves once the /// [`IncomingRelayReq`] is dropped. pub(crate) fn new( diff --git a/protocols/relay/src/protocol/outgoing_dst_req.rs b/protocols/relay/src/protocol/outgoing_dst_req.rs index 7cffb1a1d96..181e31ef4a8 100644 --- a/protocols/relay/src/protocol/outgoing_dst_req.rs +++ b/protocols/relay/src/protocol/outgoing_dst_req.rs @@ -27,7 +27,7 @@ use futures::prelude::*; use libp2p_core::{upgrade, Multiaddr, PeerId}; use libp2p_swarm::NegotiatedSubstream; use prost::Message; -use std::{fmt, error, iter}; +use std::{error, fmt, iter}; use unsigned_varint::codec::UviBytes; /// Ask the remote to become a destination. The upgrade succeeds if the remote accepts, and fails @@ -96,14 +96,9 @@ impl upgrade::OutboundUpgrade for OutgoingDstReq { async move { substream.send(std::io::Cursor::new(self.message)).await?; - let msg = - substream - .next() - .await - .ok_or_else(|| OutgoingDstReqError::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "", - )))??; + let msg = substream.next().await.ok_or_else(|| { + OutgoingDstReqError::Io(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "")) + })??; let msg = std::io::Cursor::new(msg); let CircuitRelay { diff --git a/protocols/relay/src/protocol/outgoing_relay_req.rs b/protocols/relay/src/protocol/outgoing_relay_req.rs index a34f10eba26..a9d450b04d7 100644 --- a/protocols/relay/src/protocol/outgoing_relay_req.rs +++ b/protocols/relay/src/protocol/outgoing_relay_req.rs @@ -103,14 +103,12 @@ impl upgrade::OutboundUpgrade for OutgoingRelayReq { async move { substream.send(std::io::Cursor::new(encoded)).await?; - let msg = - substream - .next() - .await - .ok_or_else(|| OutgoingRelayReqError::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "", - )))??; + let msg = substream.next().await.ok_or_else(|| { + OutgoingRelayReqError::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "", + )) + })??; let msg = std::io::Cursor::new(msg); let CircuitRelay { diff --git a/protocols/relay/src/transport.rs b/protocols/relay/src/transport.rs index b3729c0582f..b410faf0af4 100644 --- a/protocols/relay/src/transport.rs +++ b/protocols/relay/src/transport.rs @@ -402,7 +402,7 @@ impl Stream for RelayListener { stream, src_peer_id, relay_addr, - relay_peer_id: _ + relay_peer_id: _, })) => { return Poll::Ready(Some(Ok(ListenerEvent::Upgrade { upgrade: RelayedListenerUpgrade::Relayed(Some(stream)), diff --git a/protocols/relay/tests/lib.rs b/protocols/relay/tests/lib.rs index 86f2128001e..0829ec87d7b 100644 --- a/protocols/relay/tests/lib.rs +++ b/protocols/relay/tests/lib.rs @@ -35,7 +35,10 @@ use libp2p_ping::{Ping, PingConfig, PingEvent}; use libp2p_plaintext::PlainText2Config; use libp2p_relay::{Relay, RelayConfig}; use libp2p_swarm::protocols_handler::KeepAlive; -use libp2p_swarm::{DummyBehaviour, NetworkBehaviour, NetworkBehaviourAction, NetworkBehaviourEventProcess, PollParameters, Swarm, SwarmEvent}; +use libp2p_swarm::{ + DummyBehaviour, NetworkBehaviour, NetworkBehaviourAction, NetworkBehaviourEventProcess, + PollParameters, Swarm, SwarmEvent, +}; use std::task::{Context, Poll}; use std::time::Duration; use void::Void; @@ -388,9 +391,9 @@ fn src_try_connect_to_offline_dst() { loop { match src_swarm.select_next_some().await { - SwarmEvent::UnreachableAddr { address, peer_id, .. } - if address == dst_addr_via_relay => - { + SwarmEvent::UnreachableAddr { + address, peer_id, .. + } if address == dst_addr_via_relay => { assert_eq!(peer_id, dst_peer_id); break; } @@ -445,9 +448,9 @@ fn src_try_connect_to_unsupported_dst() { loop { match src_swarm.select_next_some().await { - SwarmEvent::UnreachableAddr { address, peer_id, .. } - if address == dst_addr_via_relay => - { + SwarmEvent::UnreachableAddr { + address, peer_id, .. + } if address == dst_addr_via_relay => { assert_eq!(peer_id, dst_peer_id); break; } @@ -495,10 +498,11 @@ fn src_try_connect_to_offline_dst_via_offline_relay() { // Source Node fail to reach Destination Node due to failure reaching Relay. match src_swarm.select_next_some().await { - SwarmEvent::UnreachableAddr { address, peer_id, .. } - if address == dst_addr_via_relay => { - assert_eq!(peer_id, dst_peer_id); - } + SwarmEvent::UnreachableAddr { + address, peer_id, .. + } if address == dst_addr_via_relay => { + assert_eq!(peer_id, dst_peer_id); + } e => panic!("{:?}", e), } }); @@ -582,11 +586,13 @@ fn firewalled_src_discover_firewalled_dst_via_kad_and_connect_to_dst_via_routabl let query_id = dst_swarm.behaviour_mut().kad.bootstrap().unwrap(); loop { match dst_swarm.select_next_some().await { - SwarmEvent::Behaviour(CombinedEvent::Kad(KademliaEvent::OutboundQueryCompleted { - id, - result: QueryResult::Bootstrap(Ok(_)), - .. - })) if query_id == id => { + SwarmEvent::Behaviour(CombinedEvent::Kad( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::Bootstrap(Ok(_)), + .. + }, + )) if query_id == id => { if dst_swarm.behaviour_mut().kad.iter_queries().count() == 0 { break; } @@ -660,11 +666,13 @@ fn firewalled_src_discover_firewalled_dst_via_kad_and_connect_to_dst_via_routabl SwarmEvent::Dialing(peer_id) if peer_id == relay_peer_id || peer_id == dst_peer_id => {} SwarmEvent::Behaviour(CombinedEvent::Ping(_)) => {} - SwarmEvent::Behaviour(CombinedEvent::Kad(KademliaEvent::OutboundQueryCompleted { - id, - result: QueryResult::GetClosestPeers(Ok(GetClosestPeersOk { .. })), - .. - })) if id == query_id => { + SwarmEvent::Behaviour(CombinedEvent::Kad( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::GetClosestPeers(Ok(GetClosestPeersOk { .. })), + .. + }, + )) if id == query_id => { tries += 1; if tries > 300 { panic!("Too many retries."); @@ -929,8 +937,12 @@ fn yield_incoming_connection_through_correct_listener() { relay_3_swarm.listen_on(relay_3_addr.clone()).unwrap(); spawn_swarm_on_pool(&pool, relay_3_swarm); - let dst_listener_via_relay_1 = dst_swarm.listen_on(relay_1_addr_incl_circuit.clone()).unwrap(); - let dst_listener_via_relay_2 = dst_swarm.listen_on(relay_2_addr_incl_circuit.clone()).unwrap(); + let dst_listener_via_relay_1 = dst_swarm + .listen_on(relay_1_addr_incl_circuit.clone()) + .unwrap(); + let dst_listener_via_relay_2 = dst_swarm + .listen_on(relay_2_addr_incl_circuit.clone()) + .unwrap(); // Listen on own address in order for relay 3 to be able to connect to destination node. let dst_listener = dst_swarm.listen_on(dst_addr.clone()).unwrap(); @@ -952,11 +964,15 @@ fn yield_incoming_connection_through_correct_listener() { SwarmEvent::NewListenAddr { address, listener_id, - } if listener_id == dst_listener_via_relay_2 => assert_eq!(address, relay_2_addr_incl_circuit), + } if listener_id == dst_listener_via_relay_2 => { + assert_eq!(address, relay_2_addr_incl_circuit) + } SwarmEvent::NewListenAddr { address, listener_id, - } if listener_id == dst_listener_via_relay_1 => assert_eq!(address, relay_1_addr_incl_circuit), + } if listener_id == dst_listener_via_relay_1 => { + assert_eq!(address, relay_1_addr_incl_circuit) + } SwarmEvent::NewListenAddr { address, listener_id, @@ -1077,7 +1093,11 @@ fn yield_incoming_connection_through_correct_listener() { pool.run_until(async { loop { match dst_swarm.select_next_some().await { - SwarmEvent::NewListenAddr { address, .. } if address == Protocol::P2pCircuit.into() => break, + SwarmEvent::NewListenAddr { address, .. } + if address == Protocol::P2pCircuit.into() => + { + break + } SwarmEvent::Behaviour(CombinedEvent::Ping(_)) => {} SwarmEvent::Behaviour(CombinedEvent::Kad(KademliaEvent::RoutingUpdated { .. @@ -1325,7 +1345,11 @@ fn build_keep_alive_only_swarm() -> Swarm { .multiplex(libp2p_yamux::YamuxConfig::default()) .boxed(); - Swarm::new(transport, DummyBehaviour::with_keep_alive(KeepAlive::Yes), local_peer_id) + Swarm::new( + transport, + DummyBehaviour::with_keep_alive(KeepAlive::Yes), + local_peer_id, + ) } fn spawn_swarm_on_pool(pool: &LocalPool, mut swarm: Swarm) { diff --git a/protocols/request-response/src/codec.rs b/protocols/request-response/src/codec.rs index bbb708081dc..5345d200843 100644 --- a/protocols/request-response/src/codec.rs +++ b/protocols/request-response/src/codec.rs @@ -38,30 +38,43 @@ pub trait RequestResponseCodec { /// Reads a request from the given I/O stream according to the /// negotiated protocol. - async fn read_request(&mut self, protocol: &Self::Protocol, io: &mut T) - -> io::Result + async fn read_request( + &mut self, + protocol: &Self::Protocol, + io: &mut T, + ) -> io::Result where T: AsyncRead + Unpin + Send; /// Reads a response from the given I/O stream according to the /// negotiated protocol. - async fn read_response(&mut self, protocol: &Self::Protocol, io: &mut T) - -> io::Result + async fn read_response( + &mut self, + protocol: &Self::Protocol, + io: &mut T, + ) -> io::Result where T: AsyncRead + Unpin + Send; /// Writes a request to the given I/O stream according to the /// negotiated protocol. - async fn write_request(&mut self, protocol: &Self::Protocol, io: &mut T, req: Self::Request) - -> io::Result<()> + async fn write_request( + &mut self, + protocol: &Self::Protocol, + io: &mut T, + req: Self::Request, + ) -> io::Result<()> where T: AsyncWrite + Unpin + Send; /// Writes a response to the given I/O stream according to the /// negotiated protocol. - async fn write_response(&mut self, protocol: &Self::Protocol, io: &mut T, res: Self::Response) - -> io::Result<()> + async fn write_response( + &mut self, + protocol: &Self::Protocol, + io: &mut T, + res: Self::Response, + ) -> io::Result<()> where T: AsyncWrite + Unpin + Send; } - diff --git a/protocols/request-response/src/handler.rs b/protocols/request-response/src/handler.rs index ddb9f042dd4..ee2550df183 100644 --- a/protocols/request-response/src/handler.rs +++ b/protocols/request-response/src/handler.rs @@ -20,37 +20,29 @@ mod protocol; -use crate::{EMPTY_QUEUE_SHRINK_THRESHOLD, RequestId}; use crate::codec::RequestResponseCodec; +use crate::{RequestId, EMPTY_QUEUE_SHRINK_THRESHOLD}; -pub use protocol::{RequestProtocol, ResponseProtocol, ProtocolSupport}; +pub use protocol::{ProtocolSupport, RequestProtocol, ResponseProtocol}; -use futures::{ - channel::oneshot, - future::BoxFuture, - prelude::*, - stream::FuturesUnordered -}; -use libp2p_core::{ - upgrade::{UpgradeError, NegotiationError}, -}; +use futures::{channel::oneshot, future::BoxFuture, prelude::*, stream::FuturesUnordered}; +use libp2p_core::upgrade::{NegotiationError, UpgradeError}; use libp2p_swarm::{ - SubstreamProtocol, protocols_handler::{ - KeepAlive, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, - } + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, + }, + SubstreamProtocol, }; use smallvec::SmallVec; use std::{ collections::VecDeque, - fmt, - io, - sync::{atomic::{AtomicU64, Ordering}, Arc}, + fmt, io, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, time::Duration, - task::{Context, Poll} }; use wasm_timer::Instant; @@ -79,12 +71,19 @@ where /// Outbound upgrades waiting to be emitted as an `OutboundSubstreamRequest`. outbound: VecDeque>, /// Inbound upgrades waiting for the incoming request. - inbound: FuturesUnordered), - oneshot::Canceled - >>>, - inbound_request_id: Arc + inbound: FuturesUnordered< + BoxFuture< + 'static, + Result< + ( + (RequestId, TCodec::Request), + oneshot::Sender, + ), + oneshot::Canceled, + >, + >, + >, + inbound_request_id: Arc, } impl RequestResponseHandler @@ -96,7 +95,7 @@ where codec: TCodec, keep_alive_timeout: Duration, substream_timeout: Duration, - inbound_request_id: Arc + inbound_request_id: Arc, ) -> Self { Self { inbound_protocols, @@ -108,7 +107,7 @@ where inbound: FuturesUnordered::new(), pending_events: VecDeque::new(), pending_error: None, - inbound_request_id + inbound_request_id, } } } @@ -117,18 +116,18 @@ where #[doc(hidden)] pub enum RequestResponseHandlerEvent where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { /// A request has been received. Request { request_id: RequestId, request: TCodec::Request, - sender: oneshot::Sender + sender: oneshot::Sender, }, /// A response has been received. Response { request_id: RequestId, - response: TCodec::Response + response: TCodec::Response, }, /// A response to an inbound request has been sent. ResponseSent(RequestId), @@ -150,28 +149,43 @@ where impl fmt::Debug for RequestResponseHandlerEvent { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - RequestResponseHandlerEvent::Request { request_id, request: _, sender: _ } => f.debug_struct("RequestResponseHandlerEvent::Request") + RequestResponseHandlerEvent::Request { + request_id, + request: _, + sender: _, + } => f + .debug_struct("RequestResponseHandlerEvent::Request") .field("request_id", request_id) .finish(), - RequestResponseHandlerEvent::Response { request_id, response: _ } => f.debug_struct("RequestResponseHandlerEvent::Response") + RequestResponseHandlerEvent::Response { + request_id, + response: _, + } => f + .debug_struct("RequestResponseHandlerEvent::Response") .field("request_id", request_id) .finish(), - RequestResponseHandlerEvent::ResponseSent(request_id) => f.debug_tuple("RequestResponseHandlerEvent::ResponseSent") + RequestResponseHandlerEvent::ResponseSent(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::ResponseSent") .field(request_id) .finish(), - RequestResponseHandlerEvent::ResponseOmission(request_id) => f.debug_tuple("RequestResponseHandlerEvent::ResponseOmission") + RequestResponseHandlerEvent::ResponseOmission(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::ResponseOmission") .field(request_id) .finish(), - RequestResponseHandlerEvent::OutboundTimeout(request_id) => f.debug_tuple("RequestResponseHandlerEvent::OutboundTimeout") + RequestResponseHandlerEvent::OutboundTimeout(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::OutboundTimeout") .field(request_id) .finish(), - RequestResponseHandlerEvent::OutboundUnsupportedProtocols(request_id) => f.debug_tuple("RequestResponseHandlerEvent::OutboundUnsupportedProtocols") + RequestResponseHandlerEvent::OutboundUnsupportedProtocols(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::OutboundUnsupportedProtocols") .field(request_id) .finish(), - RequestResponseHandlerEvent::InboundTimeout(request_id) => f.debug_tuple("RequestResponseHandlerEvent::InboundTimeout") + RequestResponseHandlerEvent::InboundTimeout(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::InboundTimeout") .field(request_id) .finish(), - RequestResponseHandlerEvent::InboundUnsupportedProtocols(request_id) => f.debug_tuple("RequestResponseHandlerEvent::InboundUnsupportedProtocols") + RequestResponseHandlerEvent::InboundUnsupportedProtocols(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::InboundUnsupportedProtocols") .field(request_id) .finish(), } @@ -212,28 +226,25 @@ where codec: self.codec.clone(), request_sender: rq_send, response_receiver: rs_recv, - request_id + request_id, }; // The handler waits for the request to come in. It then emits // `RequestResponseHandlerEvent::Request` together with a // `ResponseChannel`. - self.inbound.push(rq_recv.map_ok(move |rq| (rq, rs_send)).boxed()); + self.inbound + .push(rq_recv.map_ok(move |rq| (rq, rs_send)).boxed()); SubstreamProtocol::new(proto, request_id).with_timeout(self.substream_timeout) } - fn inject_fully_negotiated_inbound( - &mut self, - sent: bool, - request_id: RequestId - ) { + fn inject_fully_negotiated_inbound(&mut self, sent: bool, request_id: RequestId) { if sent { - self.pending_events.push_back( - RequestResponseHandlerEvent::ResponseSent(request_id)) + self.pending_events + .push_back(RequestResponseHandlerEvent::ResponseSent(request_id)) } else { - self.pending_events.push_back( - RequestResponseHandlerEvent::ResponseOmission(request_id)) + self.pending_events + .push_back(RequestResponseHandlerEvent::ResponseOmission(request_id)) } } @@ -242,9 +253,10 @@ where response: TCodec::Response, request_id: RequestId, ) { - self.pending_events.push_back( - RequestResponseHandlerEvent::Response { - request_id, response + self.pending_events + .push_back(RequestResponseHandlerEvent::Response { + request_id, + response, }); } @@ -260,8 +272,8 @@ where ) { match error { ProtocolsHandlerUpgrErr::Timeout => { - self.pending_events.push_back( - RequestResponseHandlerEvent::OutboundTimeout(info)); + self.pending_events + .push_back(RequestResponseHandlerEvent::OutboundTimeout(info)); } ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { // The remote merely doesn't support the protocol(s) we requested. @@ -270,7 +282,8 @@ where // An event is reported to permit user code to react to the fact that // the remote peer does not support the requested protocol(s). self.pending_events.push_back( - RequestResponseHandlerEvent::OutboundUnsupportedProtocols(info)); + RequestResponseHandlerEvent::OutboundUnsupportedProtocols(info), + ); } _ => { // Anything else is considered a fatal error or misbehaviour of @@ -283,12 +296,12 @@ where fn inject_listen_upgrade_error( &mut self, info: RequestId, - error: ProtocolsHandlerUpgrErr + error: ProtocolsHandlerUpgrErr, ) { match error { - ProtocolsHandlerUpgrErr::Timeout => { - self.pending_events.push_back(RequestResponseHandlerEvent::InboundTimeout(info)) - } + ProtocolsHandlerUpgrErr::Timeout => self + .pending_events + .push_back(RequestResponseHandlerEvent::InboundTimeout(info)), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { // The local peer merely doesn't support the protocol(s) requested. // This is no reason to close the connection, which may @@ -296,7 +309,8 @@ where // An event is reported to permit user code to react to the fact that // the local peer does not support the requested protocol(s). self.pending_events.push_back( - RequestResponseHandlerEvent::InboundUnsupportedProtocols(info)); + RequestResponseHandlerEvent::InboundUnsupportedProtocols(info), + ); } _ => { // Anything else is considered a fatal error or misbehaviour of @@ -313,18 +327,17 @@ where fn poll( &mut self, cx: &mut Context<'_>, - ) -> Poll< - ProtocolsHandlerEvent, RequestId, Self::OutEvent, Self::Error>, - > { + ) -> Poll, RequestId, Self::OutEvent, Self::Error>> + { // Check for a pending (fatal) error. if let Some(err) = self.pending_error.take() { // The handler will not be polled again by the `Swarm`. - return Poll::Ready(ProtocolsHandlerEvent::Close(err)) + return Poll::Ready(ProtocolsHandlerEvent::Close(err)); } // Drain pending events. if let Some(event) = self.pending_events.pop_front() { - return Poll::Ready(ProtocolsHandlerEvent::Custom(event)) + return Poll::Ready(ProtocolsHandlerEvent::Custom(event)); } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { self.pending_events.shrink_to_fit(); } @@ -337,8 +350,11 @@ where self.keep_alive = KeepAlive::Yes; return Poll::Ready(ProtocolsHandlerEvent::Custom( RequestResponseHandlerEvent::Request { - request_id: id, request: rq, sender: rs_sender - })) + request_id: id, + request: rq, + sender: rs_sender, + }, + )); } Err(oneshot::Canceled) => { // The inbound upgrade has errored or timed out reading @@ -351,12 +367,10 @@ where // Emit outbound requests. if let Some(request) = self.outbound.pop_front() { let info = request.request_id; - return Poll::Ready( - ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(request, info) - .with_timeout(self.substream_timeout) - }, - ) + return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(request, info) + .with_timeout(self.substream_timeout), + }); } debug_assert!(self.outbound.is_empty()); diff --git a/protocols/request-response/src/handler/protocol.rs b/protocols/request-response/src/handler/protocol.rs index cede827df27..dda4ee00d2d 100644 --- a/protocols/request-response/src/handler/protocol.rs +++ b/protocols/request-response/src/handler/protocol.rs @@ -23,8 +23,8 @@ //! receives a request and sends a response, whereas the //! outbound upgrade send a request and receives a response. -use crate::RequestId; use crate::codec::RequestResponseCodec; +use crate::RequestId; use futures::{channel::oneshot, future::BoxFuture, prelude::*}; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; @@ -40,7 +40,7 @@ pub enum ProtocolSupport { /// The protocol is only supported for outbound requests. Outbound, /// The protocol is supported for inbound and outbound requests. - Full + Full, } impl ProtocolSupport { @@ -67,19 +67,18 @@ impl ProtocolSupport { #[derive(Debug)] pub struct ResponseProtocol where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { pub(crate) codec: TCodec, pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>, pub(crate) request_sender: oneshot::Sender<(RequestId, TCodec::Request)>, pub(crate) response_receiver: oneshot::Receiver, - pub(crate) request_id: RequestId - + pub(crate) request_id: RequestId, } impl UpgradeInfo for ResponseProtocol where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { type Info = TCodec::Protocol; type InfoIter = smallvec::IntoIter<[Self::Info; 2]>; @@ -97,7 +96,11 @@ where type Error = io::Error; type Future = BoxFuture<'static, Result>; - fn upgrade_inbound(mut self, mut io: NegotiatedSubstream, protocol: Self::Info) -> Self::Future { + fn upgrade_inbound( + mut self, + mut io: NegotiatedSubstream, + protocol: Self::Info, + ) -> Self::Future { async move { let read = self.codec.read_request(&protocol, &mut io); let request = read.await?; @@ -129,7 +132,7 @@ where /// Sends a request and receives a response. pub struct RequestProtocol where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { pub(crate) codec: TCodec, pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>, @@ -150,7 +153,7 @@ where impl UpgradeInfo for RequestProtocol where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { type Info = TCodec::Protocol; type InfoIter = smallvec::IntoIter<[Self::Info; 2]>; @@ -168,7 +171,11 @@ where type Error = io::Error; type Future = BoxFuture<'static, Result>; - fn upgrade_outbound(mut self, mut io: NegotiatedSubstream, protocol: Self::Info) -> Self::Future { + fn upgrade_outbound( + mut self, + mut io: NegotiatedSubstream, + protocol: Self::Info, + ) -> Self::Future { async move { let write = self.codec.write_request(&protocol, &mut io, self.request); write.await?; @@ -176,6 +183,7 @@ where let read = self.codec.read_response(&protocol, &mut io); let response = read.await?; Ok(response) - }.boxed() + } + .boxed() } } diff --git a/protocols/request-response/src/lib.rs b/protocols/request-response/src/lib.rs index 7e5fd58c5c1..a2277e4c8df 100644 --- a/protocols/request-response/src/lib.rs +++ b/protocols/request-response/src/lib.rs @@ -60,38 +60,23 @@ pub mod codec; pub mod handler; pub mod throttled; -pub use codec::{RequestResponseCodec, ProtocolName}; +pub use codec::{ProtocolName, RequestResponseCodec}; pub use handler::ProtocolSupport; pub use throttled::Throttled; -use futures::{ - channel::oneshot, -}; -use handler::{ - RequestProtocol, - RequestResponseHandler, - RequestResponseHandlerEvent, -}; -use libp2p_core::{ - ConnectedPoint, - Multiaddr, - PeerId, - connection::ConnectionId, -}; +use futures::channel::oneshot; +use handler::{RequestProtocol, RequestResponseHandler, RequestResponseHandlerEvent}; +use libp2p_core::{connection::ConnectionId, ConnectedPoint, Multiaddr, PeerId}; use libp2p_swarm::{ - DialPeerCondition, - NetworkBehaviour, - NetworkBehaviourAction, - NotifyHandler, - PollParameters, + DialPeerCondition, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters, }; use smallvec::SmallVec; use std::{ collections::{HashMap, HashSet, VecDeque}, fmt, - time::Duration, sync::{atomic::AtomicU64, Arc}, - task::{Context, Poll} + task::{Context, Poll}, + time::Duration, }; /// An inbound request or response. @@ -117,7 +102,7 @@ pub enum RequestResponseMessage /// The peer who sent the message. peer: PeerId, /// The incoming message. - message: RequestResponseMessage + message: RequestResponseMessage, }, /// An outbound request failed. OutboundFailure { @@ -186,8 +171,12 @@ impl fmt::Display for OutboundFailure { match self { OutboundFailure::DialFailure => write!(f, "Failed to dial the requested peer"), OutboundFailure::Timeout => write!(f, "Timeout while waiting for a response"), - OutboundFailure::ConnectionClosed => write!(f, "Connection was closed before a response was received"), - OutboundFailure::UnsupportedProtocols => write!(f, "The remote supports none of the requested protocols") + OutboundFailure::ConnectionClosed => { + write!(f, "Connection was closed before a response was received") + } + OutboundFailure::UnsupportedProtocols => { + write!(f, "The remote supports none of the requested protocols") + } } } } @@ -217,10 +206,20 @@ pub enum InboundFailure { impl fmt::Display for InboundFailure { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - InboundFailure::Timeout => write!(f, "Timeout while receiving request or sending response"), - InboundFailure::ConnectionClosed => write!(f, "Connection was closed before a response could be sent"), - InboundFailure::UnsupportedProtocols => write!(f, "The local peer supports none of the protocols requested by the remote"), - InboundFailure::ResponseOmission => write!(f, "The response channel was dropped without sending a response to the remote") + InboundFailure::Timeout => { + write!(f, "Timeout while receiving request or sending response") + } + InboundFailure::ConnectionClosed => { + write!(f, "Connection was closed before a response could be sent") + } + InboundFailure::UnsupportedProtocols => write!( + f, + "The local peer supports none of the protocols requested by the remote" + ), + InboundFailure::ResponseOmission => write!( + f, + "The response channel was dropped without sending a response to the remote" + ), } } } @@ -322,7 +321,9 @@ where pending_events: VecDeque< NetworkBehaviourAction< RequestProtocol, - RequestResponseEvent>>, + RequestResponseEvent, + >, + >, /// The currently connected peers, their pending outbound and inbound responses and their known, /// reachable addresses, if any. connected: HashMap>, @@ -341,7 +342,7 @@ where /// protocols, codec and configuration. pub fn new(codec: TCodec, protocols: I, cfg: RequestResponseConfig) -> Self where - I: IntoIterator + I: IntoIterator, { let mut inbound_protocols = SmallVec::new(); let mut outbound_protocols = SmallVec::new(); @@ -375,7 +376,7 @@ where where I: IntoIterator, TCodec: Send, - TCodec::Protocol: Sync + TCodec::Protocol: Sync, { Throttled::new(c, protos, cfg) } @@ -402,11 +403,15 @@ where }; if let Some(request) = self.try_send_request(peer, request) { - self.pending_events.push_back(NetworkBehaviourAction::DialPeer { - peer_id: *peer, - condition: DialPeerCondition::Disconnected, - }); - self.pending_outbound_requests.entry(*peer).or_default().push(request); + self.pending_events + .push_back(NetworkBehaviourAction::DialPeer { + peer_id: *peer, + condition: DialPeerCondition::Disconnected, + }); + self.pending_outbound_requests + .entry(*peer) + .or_default() + .push(request); } request_id @@ -423,9 +428,11 @@ where /// /// The provided `ResponseChannel` is obtained from an inbound /// [`RequestResponseMessage::Request`]. - pub fn send_response(&mut self, ch: ResponseChannel, rs: TCodec::Response) - -> Result<(), TCodec::Response> - { + pub fn send_response( + &mut self, + ch: ResponseChannel, + rs: TCodec::Response, + ) -> Result<(), TCodec::Response> { ch.sender.send(rs) } @@ -464,12 +471,19 @@ where /// pending, i.e. waiting for a response. pub fn is_pending_outbound(&self, peer: &PeerId, request_id: &RequestId) -> bool { // Check if request is already sent on established connection. - let est_conn = self.connected.get(peer) - .map(|cs| cs.iter().any(|c| c.pending_inbound_responses.contains(request_id))) + let est_conn = self + .connected + .get(peer) + .map(|cs| { + cs.iter() + .any(|c| c.pending_inbound_responses.contains(request_id)) + }) .unwrap_or(false); // Check if request is still pending to be sent. - let pen_conn = self.pending_outbound_requests.get(peer) - .map(|rps| rps.iter().any(|rp| {rp.request_id == *request_id})) + let pen_conn = self + .pending_outbound_requests + .get(peer) + .map(|rps| rps.iter().any(|rp| rp.request_id == *request_id)) .unwrap_or(false); est_conn || pen_conn @@ -479,8 +493,12 @@ where /// [`PeerId`] is still pending, i.e. waiting for a response by the local /// node through [`RequestResponse::send_response`]. pub fn is_pending_inbound(&self, peer: &PeerId, request_id: &RequestId) -> bool { - self.connected.get(peer) - .map(|cs| cs.iter().any(|c| c.pending_outbound_responses.contains(request_id))) + self.connected + .get(peer) + .map(|cs| { + cs.iter() + .any(|c| c.pending_outbound_responses.contains(request_id)) + }) .unwrap_or(false) } @@ -494,21 +512,24 @@ where /// Tries to send a request by queueing an appropriate event to be /// emitted to the `Swarm`. If the peer is not currently connected, /// the given request is return unchanged. - fn try_send_request(&mut self, peer: &PeerId, request: RequestProtocol) - -> Option> - { + fn try_send_request( + &mut self, + peer: &PeerId, + request: RequestProtocol, + ) -> Option> { if let Some(connections) = self.connected.get_mut(peer) { if connections.is_empty() { - return Some(request) + return Some(request); } let ix = (request.request_id.0 as usize) % connections.len(); let conn = &mut connections[ix]; conn.pending_inbound_responses.insert(request.request_id); - self.pending_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *peer, - handler: NotifyHandler::One(conn.id), - event: request - }); + self.pending_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *peer, + handler: NotifyHandler::One(conn.id), + event: request, + }); None } else { Some(request) @@ -554,9 +575,9 @@ where peer: &PeerId, connection: ConnectionId, ) -> Option<&mut Connection> { - self.connected.get_mut(peer).and_then(|connections| { - connections.iter_mut().find(|c| c.id == connection) - }) + self.connected + .get_mut(peer) + .and_then(|connections| connections.iter_mut().find(|c| c.id == connection)) } } @@ -573,7 +594,7 @@ where self.codec.clone(), self.config.connection_keep_alive, self.config.request_timeout, - self.next_inbound_id.clone() + self.next_inbound_id.clone(), ) } @@ -597,21 +618,35 @@ where } } - fn inject_connection_established(&mut self, peer: &PeerId, conn: &ConnectionId, endpoint: &ConnectedPoint) { + fn inject_connection_established( + &mut self, + peer: &PeerId, + conn: &ConnectionId, + endpoint: &ConnectedPoint, + ) { let address = match endpoint { ConnectedPoint::Dialer { address } => Some(address.clone()), - ConnectedPoint::Listener { .. } => None + ConnectedPoint::Listener { .. } => None, }; - self.connected.entry(*peer) + self.connected + .entry(*peer) .or_default() .push(Connection::new(*conn, address)); } - fn inject_connection_closed(&mut self, peer_id: &PeerId, conn: &ConnectionId, _: &ConnectedPoint) { - let connections = self.connected.get_mut(peer_id) + fn inject_connection_closed( + &mut self, + peer_id: &PeerId, + conn: &ConnectionId, + _: &ConnectedPoint, + ) { + let connections = self + .connected + .get_mut(peer_id) .expect("Expected some established connection to peer before closing."); - let connection = connections.iter() + let connection = connections + .iter() .position(|c| &c.id == conn) .map(|p: usize| connections.remove(p)) .expect("Expected connection to be established before closing."); @@ -621,24 +656,25 @@ where } for request_id in connection.pending_outbound_responses { - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::InboundFailure { - peer: *peer_id, - request_id, - error: InboundFailure::ConnectionClosed - } - )); - + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::InboundFailure { + peer: *peer_id, + request_id, + error: InboundFailure::ConnectionClosed, + }, + )); } for request_id in connection.pending_inbound_responses { - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::OutboundFailure { - peer: *peer_id, - request_id, - error: OutboundFailure::ConnectionClosed - } - )); + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::OutboundFailure { + peer: *peer_id, + request_id, + error: OutboundFailure::ConnectionClosed, + }, + )); } } @@ -655,13 +691,14 @@ where // another, concurrent dialing attempt ongoing. if let Some(pending) = self.pending_outbound_requests.remove(peer) { for request in pending { - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::OutboundFailure { - peer: *peer, - request_id: request.request_id, - error: OutboundFailure::DialFailure - } - )); + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::OutboundFailure { + peer: *peer, + request_id: request.request_id, + error: OutboundFailure::DialFailure, + }, + )); } } } @@ -673,49 +710,74 @@ where event: RequestResponseHandlerEvent, ) { match event { - RequestResponseHandlerEvent::Response { request_id, response } => { + RequestResponseHandlerEvent::Response { + request_id, + response, + } => { let removed = self.remove_pending_inbound_response(&peer, connection, &request_id); debug_assert!( removed, "Expect request_id to be pending before receiving response.", ); - let message = RequestResponseMessage::Response { request_id, response }; - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::Message { peer, message })); + let message = RequestResponseMessage::Response { + request_id, + response, + }; + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::Message { peer, message }, + )); } - RequestResponseHandlerEvent::Request { request_id, request, sender } => { - let channel = ResponseChannel { request_id, peer, sender }; - let message = RequestResponseMessage::Request { request_id, request, channel }; - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::Message { peer, message } - )); + RequestResponseHandlerEvent::Request { + request_id, + request, + sender, + } => { + let channel = ResponseChannel { + request_id, + peer, + sender, + }; + let message = RequestResponseMessage::Request { + request_id, + request, + channel, + }; + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::Message { peer, message }, + )); match self.get_connection_mut(&peer, connection) { Some(connection) => { let inserted = connection.pending_outbound_responses.insert(request_id); debug_assert!(inserted, "Expect id of new request to be unknown."); - }, + } // Connection closed after `RequestResponseEvent::Request` has been emitted. None => { - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::InboundFailure { - peer, - request_id, - error: InboundFailure::ConnectionClosed - } - )); + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::InboundFailure { + peer, + request_id, + error: InboundFailure::ConnectionClosed, + }, + )); } } } RequestResponseHandlerEvent::ResponseSent(request_id) => { let removed = self.remove_pending_outbound_response(&peer, connection, request_id); - debug_assert!(removed, "Expect request_id to be pending before response is sent."); + debug_assert!( + removed, + "Expect request_id to be pending before response is sent." + ); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::ResponseSent { peer, request_id })); + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::ResponseSent { peer, request_id }, + )); } RequestResponseHandlerEvent::ResponseOmission(request_id) => { let removed = self.remove_pending_outbound_response(&peer, connection, request_id); @@ -724,25 +786,30 @@ where "Expect request_id to be pending before response is omitted.", ); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::InboundFailure { peer, request_id, - error: InboundFailure::ResponseOmission - })); + error: InboundFailure::ResponseOmission, + }, + )); } RequestResponseHandlerEvent::OutboundTimeout(request_id) => { let removed = self.remove_pending_inbound_response(&peer, connection, &request_id); - debug_assert!(removed, "Expect request_id to be pending before request times out."); + debug_assert!( + removed, + "Expect request_id to be pending before request times out." + ); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::OutboundFailure { peer, request_id, error: OutboundFailure::Timeout, - })); + }, + )); } RequestResponseHandlerEvent::InboundTimeout(request_id) => { // Note: `RequestResponseHandlerEvent::InboundTimeout` is emitted both for timing @@ -751,13 +818,14 @@ where // not assert the request_id to be present before removing it. self.remove_pending_outbound_response(&peer, connection, request_id); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::InboundFailure { peer, request_id, error: InboundFailure::Timeout, - })); + }, + )); } RequestResponseHandlerEvent::OutboundUnsupportedProtocols(request_id) => { let removed = self.remove_pending_inbound_response(&peer, connection, &request_id); @@ -766,35 +834,41 @@ where "Expect request_id to be pending before failing to connect.", ); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::OutboundFailure { peer, request_id, error: OutboundFailure::UnsupportedProtocols, - })); + }, + )); } RequestResponseHandlerEvent::InboundUnsupportedProtocols(request_id) => { // Note: No need to call `self.remove_pending_outbound_response`, // `RequestResponseHandlerEvent::Request` was never emitted for this request and // thus request was never added to `pending_outbound_responses`. - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::InboundFailure { peer, request_id, error: InboundFailure::UnsupportedProtocols, - })); + }, + )); } } } - fn poll(&mut self, _: &mut Context<'_>, _: &mut impl PollParameters) - -> Poll, + _: &mut impl PollParameters, + ) -> Poll< + NetworkBehaviourAction< RequestProtocol, - RequestResponseEvent - >> - { + RequestResponseEvent, + >, + > { if let Some(ev) = self.pending_events.pop_front() { return Poll::Ready(ev); } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { @@ -821,7 +895,7 @@ struct Connection { pending_outbound_responses: HashSet, /// Pending inbound responses for previously sent requests on this /// connection. - pending_inbound_responses: HashSet + pending_inbound_responses: HashSet, } impl Connection { diff --git a/protocols/request-response/src/throttled.rs b/protocols/request-response/src/throttled.rs index 611331f4f52..c882f41b211 100644 --- a/protocols/request-response/src/throttled.rs +++ b/protocols/request-response/src/throttled.rs @@ -36,22 +36,20 @@ mod codec; -use codec::{Codec, Message, ProtocolWrapper, Type}; +use super::{ + ProtocolSupport, RequestId, RequestResponse, RequestResponseCodec, RequestResponseConfig, + RequestResponseEvent, RequestResponseMessage, +}; use crate::handler::{RequestProtocol, RequestResponseHandler, RequestResponseHandlerEvent}; +use codec::{Codec, Message, ProtocolWrapper, Type}; use futures::ready; -use libp2p_core::{ConnectedPoint, connection::ConnectionId, Multiaddr, PeerId}; +use libp2p_core::{connection::ConnectionId, ConnectedPoint, Multiaddr, PeerId}; use libp2p_swarm::{NetworkBehaviour, NetworkBehaviourAction, PollParameters}; use lru::LruCache; -use std::{collections::{HashMap, HashSet, VecDeque}, task::{Context, Poll}}; use std::{cmp::max, num::NonZeroU16}; -use super::{ - ProtocolSupport, - RequestId, - RequestResponse, - RequestResponseCodec, - RequestResponseConfig, - RequestResponseEvent, - RequestResponseMessage, +use std::{ + collections::{HashMap, HashSet, VecDeque}, + task::{Context, Poll}, }; pub type ResponseChannel = super::ResponseChannel>; @@ -60,7 +58,7 @@ pub type ResponseChannel = super::ResponseChannel>; pub struct Throttled where C: RequestResponseCodec + Send, - C::Protocol: Sync + C::Protocol: Sync, { /// A random id used for logging. id: u32, @@ -77,7 +75,7 @@ where /// Pending events to report in `Throttled::poll`. events: VecDeque>>, /// The current credit ID. - next_grant_id: u64 + next_grant_id: u64, } /// Information about a credit grant that is sent to remote peers. @@ -89,7 +87,7 @@ struct Grant { request: RequestId, /// The credit given in this grant, i.e. the number of additional /// requests the remote is allowed to send. - credit: u16 + credit: u16, } /// Max. number of inbound requests that can be received. @@ -99,7 +97,7 @@ struct Limit { max_recv: NonZeroU16, /// The next receive limit which becomes active after /// the current limit has been reached. - next_max: NonZeroU16 + next_max: NonZeroU16, } impl Limit { @@ -111,7 +109,7 @@ impl Limit { // sender so we must not use `max` right away. Limit { max_recv: NonZeroU16::new(1).expect("1 > 0"), - next_max: max + next_max: max, } } @@ -191,7 +189,7 @@ impl PeerInfo { limit: recv_limit, remaining: 1, sent: HashSet::new(), - } + }, } } @@ -210,16 +208,18 @@ impl PeerInfo { impl Throttled where C: RequestResponseCodec + Send + Clone, - C::Protocol: Sync + C::Protocol: Sync, { /// Create a new throttled request-response behaviour. pub fn new(c: C, protos: I, cfg: RequestResponseConfig) -> Self where I: IntoIterator, C: Send, - C::Protocol: Sync + C::Protocol: Sync, { - let protos = protos.into_iter().map(|(p, ps)| (ProtocolWrapper::new(b"/t/1", p), ps)); + let protos = protos + .into_iter() + .map(|(p, ps)| (ProtocolWrapper::new(b"/t/1", p), ps)); Throttled::from(RequestResponse::new(Codec::new(c, 8192), protos, cfg)) } @@ -233,7 +233,7 @@ where default_limit: Limit::new(NonZeroU16::new(1).expect("1 > 0")), limit_overrides: HashMap::new(), events: VecDeque::new(), - next_grant_id: 0 + next_grant_id: 0, } } @@ -262,7 +262,10 @@ where /// Has the limit of outbound requests been reached for the given peer? pub fn can_send(&mut self, p: &PeerId) -> bool { - self.peer_info.get(p).map(|i| i.send_budget.remaining > 0).unwrap_or(true) + self.peer_info + .get(p) + .map(|i| i.send_budget.remaining > 0) + .unwrap_or(true) } /// Send a request to a peer. @@ -273,22 +276,30 @@ where pub fn send_request(&mut self, p: &PeerId, req: C::Request) -> Result { let connected = &mut self.peer_info; let disconnected = &mut self.offline_peer_info; - let remaining = - if let Some(info) = connected.get_mut(p).or_else(|| disconnected.get_mut(p)) { - if info.send_budget.remaining == 0 { - log::trace!("{:08x}: no more budget to send another request to {}", self.id, p); - return Err(req) - } - info.send_budget.remaining -= 1; - info.send_budget.remaining - } else { - let limit = self.limit_overrides.get(p).copied().unwrap_or(self.default_limit); - let mut info = PeerInfo::new(limit); - info.send_budget.remaining -= 1; - let remaining = info.send_budget.remaining; - self.offline_peer_info.put(*p, info); - remaining - }; + let remaining = if let Some(info) = connected.get_mut(p).or_else(|| disconnected.get_mut(p)) + { + if info.send_budget.remaining == 0 { + log::trace!( + "{:08x}: no more budget to send another request to {}", + self.id, + p + ); + return Err(req); + } + info.send_budget.remaining -= 1; + info.send_budget.remaining + } else { + let limit = self + .limit_overrides + .get(p) + .copied() + .unwrap_or(self.default_limit); + let mut info = PeerInfo::new(limit); + info.send_budget.remaining -= 1; + let remaining = info.send_budget.remaining; + self.offline_peer_info.put(*p, info); + remaining + }; let rid = self.behaviour.send_request(p, Message::request(req)); @@ -305,12 +316,20 @@ where /// Answer an inbound request with a response. /// /// See [`RequestResponse::send_response`] for details. - pub fn send_response(&mut self, ch: ResponseChannel, res: C::Response) - -> Result<(), C::Response> - { - log::trace!("{:08x}: sending response {} to peer {}", self.id, ch.request_id(), &ch.peer); + pub fn send_response( + &mut self, + ch: ResponseChannel, + res: C::Response, + ) -> Result<(), C::Response> { + log::trace!( + "{:08x}: sending response {} to peer {}", + self.id, + ch.request_id(), + &ch.peer + ); if let Some(info) = self.peer_info.get_mut(&ch.peer) { - if info.recv_budget.remaining == 0 { // need to send more credit to the remote peer + if info.recv_budget.remaining == 0 { + // need to send more credit to the remote peer let crd = info.recv_budget.limit.switch(); info.recv_budget.remaining = info.recv_budget.limit.max_recv.get(); self.send_credit(&ch.peer, crd); @@ -350,7 +369,6 @@ where self.behaviour.is_pending_outbound(p, r) } - /// Is the remote waiting for the local node to respond to the given /// request? /// @@ -365,8 +383,18 @@ where let cid = self.next_grant_id; self.next_grant_id += 1; let rid = self.behaviour.send_request(p, Message::credit(credit, cid)); - log::trace!("{:08x}: sending {} credit as grant {} to {}", self.id, credit, cid, p); - let grant = Grant { id: cid, request: rid, credit }; + log::trace!( + "{:08x}: sending {} credit as grant {} to {}", + self.id, + credit, + cid, + p + ); + let grant = Grant { + id: cid, + request: rid, + credit, + }; info.recv_budget.grant = Some(grant); info.recv_budget.sent.insert(rid); } @@ -383,13 +411,13 @@ pub enum Event { /// When previously reaching the send limit of a peer, /// this event is eventually emitted when sending is /// allowed to resume. - ResumeSending(PeerId) + ResumeSending(PeerId), } impl NetworkBehaviour for Throttled where C: RequestResponseCodec + Send + Clone + 'static, - C::Protocol: Sync + C::Protocol: Sync, { type ProtocolsHandler = RequestResponseHandler>; type OutEvent = Event>; @@ -402,7 +430,12 @@ where self.behaviour.addresses_of_peer(p) } - fn inject_connection_established(&mut self, p: &PeerId, id: &ConnectionId, end: &ConnectedPoint) { + fn inject_connection_established( + &mut self, + p: &PeerId, + id: &ConnectionId, + end: &ConnectedPoint, + ) { self.behaviour.inject_connection_established(p, id, end) } @@ -433,7 +466,11 @@ where self.send_credit(p, recv_budget - 1); } } else { - let limit = self.limit_overrides.get(p).copied().unwrap_or(self.default_limit); + let limit = self + .limit_overrides + .get(p) + .copied() + .unwrap_or(self.default_limit); self.peer_info.insert(*p, PeerInfo::new(limit)); } } @@ -451,142 +488,183 @@ where self.behaviour.inject_dial_failure(p) } - fn inject_event(&mut self, p: PeerId, i: ConnectionId, e: RequestResponseHandlerEvent>) { + fn inject_event( + &mut self, + p: PeerId, + i: ConnectionId, + e: RequestResponseHandlerEvent>, + ) { self.behaviour.inject_event(p, i, e) } - fn poll(&mut self, cx: &mut Context<'_>, params: &mut impl PollParameters) - -> Poll>, Self::OutEvent>> - { + fn poll( + &mut self, + cx: &mut Context<'_>, + params: &mut impl PollParameters, + ) -> Poll>, Self::OutEvent>> { loop { if let Some(ev) = self.events.pop_front() { - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(ev)) + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(ev)); } else if self.events.capacity() > super::EMPTY_QUEUE_SHRINK_THRESHOLD { self.events.shrink_to_fit() } let event = match ready!(self.behaviour.poll(cx, params)) { - | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::Message { peer, message }) => { + NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::Message { + peer, + message, + }) => { let message = match message { - | RequestResponseMessage::Response { request_id, response } => - match &response.header().typ { - | Some(Type::Ack) => { - if let Some(info) = self.peer_info.get_mut(&peer) { - if let Some(id) = info.recv_budget.grant.as_ref().map(|c| c.id) { - if Some(id) == response.header().ident { - log::trace!("{:08x}: received ack {} from {}", self.id, id, peer); - info.recv_budget.grant = None; - } + RequestResponseMessage::Response { + request_id, + response, + } => match &response.header().typ { + Some(Type::Ack) => { + if let Some(info) = self.peer_info.get_mut(&peer) { + if let Some(id) = info.recv_budget.grant.as_ref().map(|c| c.id) + { + if Some(id) == response.header().ident { + log::trace!( + "{:08x}: received ack {} from {}", + self.id, + id, + peer + ); + info.recv_budget.grant = None; } - info.recv_budget.sent.remove(&request_id); } - continue + info.recv_budget.sent.remove(&request_id); + } + continue; + } + Some(Type::Response) => { + log::trace!( + "{:08x}: received response {} from {}", + self.id, + request_id, + peer + ); + if let Some(rs) = response.into_parts().1 { + RequestResponseMessage::Response { + request_id, + response: rs, + } + } else { + log::error! { "{:08x}: missing data for response {} from peer {}", + self.id, + request_id, + peer + } + continue; } - | Some(Type::Response) => { - log::trace!("{:08x}: received response {} from {}", self.id, request_id, peer); - if let Some(rs) = response.into_parts().1 { - RequestResponseMessage::Response { request_id, response: rs } + } + ty => { + log::trace! { + "{:08x}: unknown message type: {:?} from {}; expected response or credit", + self.id, + ty, + peer + }; + continue; + } + }, + RequestResponseMessage::Request { + request_id, + request, + channel, + } => match &request.header().typ { + Some(Type::Credit) => { + if let Some(info) = self.peer_info.get_mut(&peer) { + let id = if let Some(n) = request.header().ident { + n } else { - log::error! { "{:08x}: missing data for response {} from peer {}", + log::warn! { "{:08x}: missing credit id in message from {}", self.id, - request_id, peer } - continue - } - } - | ty => { - log::trace! { - "{:08x}: unknown message type: {:?} from {}; expected response or credit", + continue; + }; + let credit = request.header().credit.unwrap_or(0); + log::trace! { "{:08x}: received {} additional credit {} from {}", self.id, - ty, + credit, + id, peer }; - continue - } - } - | RequestResponseMessage::Request { request_id, request, channel } => - match &request.header().typ { - | Some(Type::Credit) => { - if let Some(info) = self.peer_info.get_mut(&peer) { - let id = if let Some(n) = request.header().ident { - n - } else { - log::warn! { "{:08x}: missing credit id in message from {}", + if info.send_budget.grant < Some(id) { + if info.send_budget.remaining == 0 && credit > 0 { + log::trace!( + "{:08x}: sending to peer {} can resume", self.id, peer - } - continue - }; - let credit = request.header().credit.unwrap_or(0); - log::trace! { "{:08x}: received {} additional credit {} from {}", - self.id, - credit, - id, - peer - }; - if info.send_budget.grant < Some(id) { - if info.send_budget.remaining == 0 && credit > 0 { - log::trace!("{:08x}: sending to peer {} can resume", self.id, peer); - self.events.push_back(Event::ResumeSending(peer)) - } - info.send_budget.remaining += credit; - info.send_budget.grant = Some(id); + ); + self.events.push_back(Event::ResumeSending(peer)) } - // Note: Failing to send a response to a credit grant is - // handled along with other inbound failures further below. - let _ = self.behaviour.send_response(channel, Message::ack(id)); - info.send_budget.received.insert(request_id); + info.send_budget.remaining += credit; + info.send_budget.grant = Some(id); } - continue + // Note: Failing to send a response to a credit grant is + // handled along with other inbound failures further below. + let _ = self.behaviour.send_response(channel, Message::ack(id)); + info.send_budget.received.insert(request_id); } - | Some(Type::Request) => { - if let Some(info) = self.peer_info.get_mut(&peer) { - log::trace! { "{:08x}: received request {} (recv. budget = {})", - self.id, - request_id, - info.recv_budget.remaining - }; - if info.recv_budget.remaining == 0 { - log::debug!("{:08x}: peer {} exceeds its budget", self.id, peer); - self.events.push_back(Event::TooManyInboundRequests(peer)); - continue - } - info.recv_budget.remaining -= 1; - // We consider a request as proof that our credit grant has - // reached the peer. Usually, an ACK has already been - // received. - info.recv_budget.grant = None; - } - if let Some(rq) = request.into_parts().1 { - RequestResponseMessage::Request { request_id, request: rq, channel } - } else { - log::error! { "{:08x}: missing data for request {} from peer {}", + continue; + } + Some(Type::Request) => { + if let Some(info) = self.peer_info.get_mut(&peer) { + log::trace! { "{:08x}: received request {} (recv. budget = {})", + self.id, + request_id, + info.recv_budget.remaining + }; + if info.recv_budget.remaining == 0 { + log::debug!( + "{:08x}: peer {} exceeds its budget", self.id, - request_id, peer - } - continue + ); + self.events.push_back(Event::TooManyInboundRequests(peer)); + continue; } + info.recv_budget.remaining -= 1; + // We consider a request as proof that our credit grant has + // reached the peer. Usually, an ACK has already been + // received. + info.recv_budget.grant = None; } - | ty => { - log::trace! { - "{:08x}: unknown message type: {:?} from {}; expected request or ack", + if let Some(rq) = request.into_parts().1 { + RequestResponseMessage::Request { + request_id, + request: rq, + channel, + } + } else { + log::error! { "{:08x}: missing data for request {} from peer {}", self.id, - ty, + request_id, peer - }; - continue + } + continue; } } + ty => { + log::trace! { + "{:08x}: unknown message type: {:?} from {}; expected request or ack", + self.id, + ty, + peer + }; + continue; + } + }, }; let event = RequestResponseEvent::Message { peer, message }; NetworkBehaviourAction::GenerateEvent(Event::Event(event)) } - | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::OutboundFailure { + NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::OutboundFailure { peer, request_id, - error + error, }) => { if let Some(info) = self.peer_info.get_mut(&peer) { if let Some(grant) = info.recv_budget.grant.as_mut() { @@ -606,16 +684,20 @@ where // If the outbound failure was for a credit message, don't report it on // the public API and retry the sending. if info.recv_budget.sent.remove(&request_id) { - continue + continue; } } - let event = RequestResponseEvent::OutboundFailure { peer, request_id, error }; + let event = RequestResponseEvent::OutboundFailure { + peer, + request_id, + error, + }; NetworkBehaviourAction::GenerateEvent(Event::Event(event)) } - | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::InboundFailure { + NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::InboundFailure { peer, request_id, - error + error, }) => { // If the inbound failure occurred in the context of responding to a // credit grant, don't report it on the public API. @@ -625,15 +707,19 @@ where "{:08}: failed to acknowledge credit grant from {}: {:?}", self.id, peer, error }; - continue + continue; } } - let event = RequestResponseEvent::InboundFailure { peer, request_id, error }; + let event = RequestResponseEvent::InboundFailure { + peer, + request_id, + error, + }; NetworkBehaviourAction::GenerateEvent(Event::Event(event)) } - | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::ResponseSent { + NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::ResponseSent { peer, - request_id + request_id, }) => { // If this event is for an ACK response that was sent for // the last received credit grant, skip it. @@ -644,25 +730,41 @@ where self.id, info.send_budget.grant, } - continue + continue; } } NetworkBehaviourAction::GenerateEvent(Event::Event( - RequestResponseEvent::ResponseSent { peer, request_id })) + RequestResponseEvent::ResponseSent { peer, request_id }, + )) + } + NetworkBehaviourAction::DialAddress { address } => { + NetworkBehaviourAction::DialAddress { address } + } + NetworkBehaviourAction::DialPeer { peer_id, condition } => { + NetworkBehaviourAction::DialPeer { peer_id, condition } + } + NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + } => NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + }, + NetworkBehaviourAction::ReportObservedAddr { address, score } => { + NetworkBehaviourAction::ReportObservedAddr { address, score } } - | NetworkBehaviourAction::DialAddress { address } => - NetworkBehaviourAction::DialAddress { address }, - | NetworkBehaviourAction::DialPeer { peer_id, condition } => - NetworkBehaviourAction::DialPeer { peer_id, condition }, - | NetworkBehaviourAction::NotifyHandler { peer_id, handler, event } => - NetworkBehaviourAction::NotifyHandler { peer_id, handler, event }, - | NetworkBehaviourAction::ReportObservedAddr { address, score } => - NetworkBehaviourAction::ReportObservedAddr { address, score }, - | NetworkBehaviourAction::CloseConnection { peer_id, connection } => - NetworkBehaviourAction::CloseConnection { peer_id, connection } + NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + } => NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }, }; - return Poll::Ready(event) + return Poll::Ready(event); } } } diff --git a/protocols/request-response/src/throttled/codec.rs b/protocols/request-response/src/throttled/codec.rs index 580fdd3da85..f82c4ae3961 100644 --- a/protocols/request-response/src/throttled/codec.rs +++ b/protocols/request-response/src/throttled/codec.rs @@ -18,13 +18,13 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use super::RequestResponseCodec; use async_trait::async_trait; use bytes::{Bytes, BytesMut}; use futures::prelude::*; use libp2p_core::ProtocolName; -use minicbor::{Encode, Decode}; +use minicbor::{Decode, Encode}; use std::io; -use super::RequestResponseCodec; use unsigned_varint::{aio, io::ReadError}; /// A protocol header. @@ -32,27 +32,34 @@ use unsigned_varint::{aio, io::ReadError}; #[cbor(map)] pub struct Header { /// The type of message. - #[n(0)] pub typ: Option, + #[n(0)] + pub typ: Option, /// The number of additional requests the remote is willing to receive. - #[n(1)] pub credit: Option, + #[n(1)] + pub credit: Option, /// An identifier used for sending credit grants. - #[n(2)] pub ident: Option + #[n(2)] + pub ident: Option, } /// A protocol message type. #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)] pub enum Type { - #[n(0)] Request, - #[n(1)] Response, - #[n(2)] Credit, - #[n(3)] Ack + #[n(0)] + Request, + #[n(1)] + Response, + #[n(2)] + Credit, + #[n(3)] + Ack, } /// A protocol message consisting of header and data. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Message { header: Header, - data: Option + data: Option, } impl Message { @@ -63,26 +70,40 @@ impl Message { /// Create a request message. pub fn request(data: T) -> Self { - let mut m = Message::new(Header { typ: Some(Type::Request), .. Header::default() }); + let mut m = Message::new(Header { + typ: Some(Type::Request), + ..Header::default() + }); m.data = Some(data); m } /// Create a response message. pub fn response(data: T) -> Self { - let mut m = Message::new(Header { typ: Some(Type::Response), .. Header::default() }); + let mut m = Message::new(Header { + typ: Some(Type::Response), + ..Header::default() + }); m.data = Some(data); m } /// Create a credit grant. pub fn credit(credit: u16, ident: u64) -> Self { - Message::new(Header { typ: Some(Type::Credit), credit: Some(credit), ident: Some(ident) }) + Message::new(Header { + typ: Some(Type::Credit), + credit: Some(credit), + ident: Some(ident), + }) } /// Create an acknowledge message. pub fn ack(ident: u64) -> Self { - Message::new(Header { typ: Some(Type::Ack), credit: None, ident: Some(ident) }) + Message::new(Header { + typ: Some(Type::Ack), + credit: None, + ident: Some(ident), + }) } /// Access the message header. @@ -130,28 +151,34 @@ pub struct Codec { /// Encoding/decoding buffer. buffer: Vec, /// Max. header length. - max_header_len: u32 + max_header_len: u32, } impl Codec { /// Create a codec by wrapping an existing one. pub fn new(c: C, max_header_len: u32) -> Self { - Codec { inner: c, buffer: Vec::new(), max_header_len } + Codec { + inner: c, + buffer: Vec::new(), + max_header_len, + } } /// Read and decode a request header. async fn read_header(&mut self, io: &mut T) -> io::Result where T: AsyncRead + Unpin + Send, - H: for<'a> minicbor::Decode<'a> + H: for<'a> minicbor::Decode<'a>, { - let header_len = aio::read_u32(&mut *io).await - .map_err(|e| match e { - ReadError::Io(e) => e, - other => io::Error::new(io::ErrorKind::Other, other) - })?; + let header_len = aio::read_u32(&mut *io).await.map_err(|e| match e { + ReadError::Io(e) => e, + other => io::Error::new(io::ErrorKind::Other, other), + })?; if header_len > self.max_header_len { - return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large to read")) + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "header too large to read", + )); } self.buffer.resize(u32_to_usize(header_len), 0u8); io.read_exact(&mut self.buffer).await?; @@ -162,12 +189,16 @@ impl Codec { async fn write_header(&mut self, hdr: &H, io: &mut T) -> io::Result<()> where T: AsyncWrite + Unpin + Send, - H: minicbor::Encode + H: minicbor::Encode, { self.buffer.clear(); - minicbor::encode(hdr, &mut self.buffer).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + minicbor::encode(hdr, &mut self.buffer) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; if self.buffer.len() > u32_to_usize(self.max_header_len) { - return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large to write")) + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "header too large to write", + )); } let mut b = unsigned_varint::encode::u32_buffer(); let header_len = unsigned_varint::encode::u32(self.buffer.len() as u32, &mut b); @@ -180,7 +211,7 @@ impl Codec { impl RequestResponseCodec for Codec where C: RequestResponseCodec + Send, - C::Protocol: Sync + C::Protocol: Sync, { type Protocol = ProtocolWrapper; type Request = Message; @@ -188,7 +219,7 @@ where async fn read_request(&mut self, p: &Self::Protocol, io: &mut T) -> io::Result where - T: AsyncRead + Unpin + Send + T: AsyncRead + Unpin + Send, { let mut msg = Message::new(self.read_header(io).await?); match msg.header.typ { @@ -198,15 +229,22 @@ where } Some(Type::Credit) => Ok(msg), Some(Type::Response) | Some(Type::Ack) | None => { - log::debug!("unexpected {:?} when expecting request or credit grant", msg.header.typ); + log::debug!( + "unexpected {:?} when expecting request or credit grant", + msg.header.typ + ); Err(io::ErrorKind::InvalidData.into()) } } } - async fn read_response(&mut self, p: &Self::Protocol, io: &mut T) -> io::Result + async fn read_response( + &mut self, + p: &Self::Protocol, + io: &mut T, + ) -> io::Result where - T: AsyncRead + Unpin + Send + T: AsyncRead + Unpin + Send, { let mut msg = Message::new(self.read_header(io).await?); match msg.header.typ { @@ -216,15 +254,23 @@ where } Some(Type::Ack) => Ok(msg), Some(Type::Request) | Some(Type::Credit) | None => { - log::debug!("unexpected {:?} when expecting response or ack", msg.header.typ); + log::debug!( + "unexpected {:?} when expecting response or ack", + msg.header.typ + ); Err(io::ErrorKind::InvalidData.into()) } } } - async fn write_request(&mut self, p: &Self::Protocol, io: &mut T, r: Self::Request) -> io::Result<()> + async fn write_request( + &mut self, + p: &Self::Protocol, + io: &mut T, + r: Self::Request, + ) -> io::Result<()> where - T: AsyncWrite + Unpin + Send + T: AsyncWrite + Unpin + Send, { self.write_header(&r.header, io).await?; if let Some(data) = r.data { @@ -233,9 +279,14 @@ where Ok(()) } - async fn write_response(&mut self, p: &Self::Protocol, io: &mut T, r: Self::Response) -> io::Result<()> + async fn write_response( + &mut self, + p: &Self::Protocol, + io: &mut T, + r: Self::Response, + ) -> io::Result<()> where - T: AsyncWrite + Unpin + Send + T: AsyncWrite + Unpin + Send, { self.write_header(&r.header, io).await?; if let Some(data) = r.data { diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index 43bed41b302..626f4effef3 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -21,22 +21,21 @@ //! Integration tests for the `RequestResponse` network behaviour. use async_trait::async_trait; +use futures::{channel::mpsc, executor::LocalPool, prelude::*, task::SpawnExt, AsyncWriteExt}; use libp2p_core::{ - Multiaddr, - PeerId, identity, muxing::StreamMuxerBox, transport::{self, Transport}, - upgrade::{self, read_length_prefixed, write_length_prefixed} + upgrade::{self, read_length_prefixed, write_length_prefixed}, + Multiaddr, PeerId, }; -use libp2p_noise::{NoiseConfig, X25519Spec, Keypair}; +use libp2p_noise::{Keypair, NoiseConfig, X25519Spec}; use libp2p_request_response::*; use libp2p_swarm::{Swarm, SwarmEvent}; use libp2p_tcp::TcpConfig; -use futures::{channel::mpsc, executor::LocalPool, prelude::*, task::SpawnExt, AsyncWriteExt}; use rand::{self, Rng}; -use std::{io, iter}; use std::{collections::HashSet, num::NonZeroU16}; +use std::{io, iter}; #[test] fn is_response_outbound() { @@ -50,24 +49,30 @@ fn is_response_outbound() { let ping_proto1 = RequestResponse::new(PingCodec(), protocols, cfg); let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id); - let request_id1 = swarm1.behaviour_mut().send_request(&offline_peer, ping.clone()); + let request_id1 = swarm1 + .behaviour_mut() + .send_request(&offline_peer, ping.clone()); match futures::executor::block_on(swarm1.select_next_some()) { - SwarmEvent::Behaviour(RequestResponseEvent::OutboundFailure{ + SwarmEvent::Behaviour(RequestResponseEvent::OutboundFailure { peer, request_id: req_id, - error: _error + error: _error, }) => { assert_eq!(&offline_peer, &peer); assert_eq!(req_id, request_id1); - }, + } e => panic!("Peer: Unexpected event: {:?}", e), } let request_id2 = swarm1.behaviour_mut().send_request(&offline_peer, ping); - assert!(!swarm1.behaviour().is_pending_outbound(&offline_peer, &request_id1)); - assert!(swarm1.behaviour().is_pending_outbound(&offline_peer, &request_id2)); + assert!(!swarm1 + .behaviour() + .is_pending_outbound(&offline_peer, &request_id1)); + assert!(swarm1 + .behaviour() + .is_pending_outbound(&offline_peer, &request_id2)); } /// Exercises a simple ping protocol. @@ -98,18 +103,22 @@ fn ping_protocol() { let peer1 = async move { loop { match swarm1.select_next_some().await { - SwarmEvent::NewListenAddr { address, .. }=> tx.send(address).await.unwrap(), + SwarmEvent::NewListenAddr { address, .. } => tx.send(address).await.unwrap(), SwarmEvent::Behaviour(RequestResponseEvent::Message { peer, - message: RequestResponseMessage::Request { request, channel, .. } + message: + RequestResponseMessage::Request { + request, channel, .. + }, }) => { assert_eq!(&request, &expected_ping); assert_eq!(&peer, &peer2_id); - swarm1.behaviour_mut().send_response(channel, pong.clone()).unwrap(); - }, - SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { - peer, .. - }) => { + swarm1 + .behaviour_mut() + .send_response(channel, pong.clone()) + .unwrap(); + } + SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { peer, .. }) => { assert_eq!(&peer, &peer2_id); } SwarmEvent::Behaviour(e) => panic!("Peer1: Unexpected event: {:?}", e), @@ -131,20 +140,23 @@ fn ping_protocol() { match swarm2.select_next_some().await { SwarmEvent::Behaviour(RequestResponseEvent::Message { peer, - message: RequestResponseMessage::Response { request_id, response } + message: + RequestResponseMessage::Response { + request_id, + response, + }, }) => { count += 1; assert_eq!(&response, &expected_pong); assert_eq!(&peer, &peer1_id); assert_eq!(req_id, request_id); if count >= num_pings { - return + return; } else { req_id = swarm2.behaviour_mut().send_request(&peer1_id, ping.clone()); } - } - SwarmEvent::Behaviour(e) =>panic!("Peer2: Unexpected event: {:?}", e), + SwarmEvent::Behaviour(e) => panic!("Peer2: Unexpected event: {:?}", e), _ => {} } } @@ -207,8 +219,8 @@ fn emits_inbound_connection_closed_failure() { loop { match swarm1.select_next_some().await { - SwarmEvent::Behaviour(RequestResponseEvent::InboundFailure { - error: InboundFailure::ConnectionClosed, + SwarmEvent::Behaviour(RequestResponseEvent::InboundFailure { + error: InboundFailure::ConnectionClosed, .. }) => break, SwarmEvent::Behaviour(e) => panic!("Peer1: Unexpected event: {:?}", e), @@ -273,7 +285,7 @@ fn emits_inbound_connection_closed_if_channel_is_dropped() { let error = match event { RequestResponseEvent::OutboundFailure { error, .. } => error, - e => panic!("unexpected event from peer 2: {:?}", e) + e => panic!("unexpected event from peer 2: {:?}", e), }; assert_eq!(error, OutboundFailure::ConnectionClosed); @@ -306,24 +318,34 @@ fn ping_protocol_throttled() { let limit1: u16 = rand::thread_rng().gen_range(1, 10); let limit2: u16 = rand::thread_rng().gen_range(1, 10); - swarm1.behaviour_mut().set_receive_limit(NonZeroU16::new(limit1).unwrap()); - swarm2.behaviour_mut().set_receive_limit(NonZeroU16::new(limit2).unwrap()); + swarm1 + .behaviour_mut() + .set_receive_limit(NonZeroU16::new(limit1).unwrap()); + swarm2 + .behaviour_mut() + .set_receive_limit(NonZeroU16::new(limit2).unwrap()); let peer1 = async move { - for i in 1 .. { + for i in 1.. { match swarm1.select_next_some().await { SwarmEvent::NewListenAddr { address, .. } => tx.send(address).await.unwrap(), SwarmEvent::Behaviour(throttled::Event::Event(RequestResponseEvent::Message { peer, - message: RequestResponseMessage::Request { request, channel, .. }, + message: + RequestResponseMessage::Request { + request, channel, .. + }, })) => { assert_eq!(&request, &expected_ping); assert_eq!(&peer, &peer2_id); - swarm1.behaviour_mut().send_response(channel, pong.clone()).unwrap(); - }, - SwarmEvent::Behaviour(throttled::Event::Event(RequestResponseEvent::ResponseSent { - peer, .. - })) => { + swarm1 + .behaviour_mut() + .send_response(channel, pong.clone()) + .unwrap(); + } + SwarmEvent::Behaviour(throttled::Event::Event( + RequestResponseEvent::ResponseSent { peer, .. }, + )) => { assert_eq!(&peer, &peer2_id); } SwarmEvent::Behaviour(e) => panic!("Peer1: Unexpected event: {:?}", e), @@ -331,7 +353,9 @@ fn ping_protocol_throttled() { } if i % 31 == 0 { let lim = rand::thread_rng().gen_range(1, 17); - swarm1.behaviour_mut().override_receive_limit(&peer2_id, NonZeroU16::new(lim).unwrap()); + swarm1 + .behaviour_mut() + .override_receive_limit(&peer2_id, NonZeroU16::new(lim).unwrap()); } } }; @@ -348,7 +372,11 @@ fn ping_protocol_throttled() { loop { if !blocked { - while let Some(id) = swarm2.behaviour_mut().send_request(&peer1_id, ping.clone()).ok() { + while let Some(id) = swarm2 + .behaviour_mut() + .send_request(&peer1_id, ping.clone()) + .ok() + { req_ids.insert(id); } blocked = true; @@ -358,19 +386,23 @@ fn ping_protocol_throttled() { assert_eq!(peer, peer1_id); blocked = false } - SwarmEvent::Behaviour(throttled::Event::Event(RequestResponseEvent::Message { + SwarmEvent::Behaviour(throttled::Event::Event(RequestResponseEvent::Message { peer, - message: RequestResponseMessage::Response { request_id, response } + message: + RequestResponseMessage::Response { + request_id, + response, + }, })) => { count += 1; assert_eq!(&response, &expected_pong); assert_eq!(&peer, &peer1_id); assert!(req_ids.remove(&request_id)); if count >= num_pings { - break + break; } } - SwarmEvent::Behaviour(e) =>panic!("Peer2: Unexpected event: {:?}", e), + SwarmEvent::Behaviour(e) => panic!("Peer2: Unexpected event: {:?}", e), _ => {} } } @@ -384,13 +416,18 @@ fn ping_protocol_throttled() { fn mk_transport() -> (PeerId, transport::Boxed<(PeerId, StreamMuxerBox)>) { let id_keys = identity::Keypair::generate_ed25519(); let peer_id = id_keys.public().to_peer_id(); - let noise_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); - (peer_id, TcpConfig::new() - .nodelay(true) - .upgrade(upgrade::Version::V1) - .authenticate(NoiseConfig::xx(noise_keys).into_authenticated()) - .multiplex(libp2p_yamux::YamuxConfig::default()) - .boxed()) + let noise_keys = Keypair::::new() + .into_authentic(&id_keys) + .unwrap(); + ( + peer_id, + TcpConfig::new() + .nodelay(true) + .upgrade(upgrade::Version::V1) + .authenticate(NoiseConfig::xx(noise_keys).into_authenticated()) + .multiplex(libp2p_yamux::YamuxConfig::default()) + .boxed(), + ) } // Simple Ping-Pong Protocol @@ -416,38 +453,40 @@ impl RequestResponseCodec for PingCodec { type Request = Ping; type Response = Pong; - async fn read_request(&mut self, _: &PingProtocol, io: &mut T) - -> io::Result + async fn read_request(&mut self, _: &PingProtocol, io: &mut T) -> io::Result where - T: AsyncRead + Unpin + Send + T: AsyncRead + Unpin + Send, { let vec = read_length_prefixed(io, 1024).await?; if vec.is_empty() { - return Err(io::ErrorKind::UnexpectedEof.into()) + return Err(io::ErrorKind::UnexpectedEof.into()); } Ok(Ping(vec)) } - async fn read_response(&mut self, _: &PingProtocol, io: &mut T) - -> io::Result + async fn read_response(&mut self, _: &PingProtocol, io: &mut T) -> io::Result where - T: AsyncRead + Unpin + Send + T: AsyncRead + Unpin + Send, { let vec = read_length_prefixed(io, 1024).await?; if vec.is_empty() { - return Err(io::ErrorKind::UnexpectedEof.into()) + return Err(io::ErrorKind::UnexpectedEof.into()); } Ok(Pong(vec)) } - async fn write_request(&mut self, _: &PingProtocol, io: &mut T, Ping(data): Ping) - -> io::Result<()> + async fn write_request( + &mut self, + _: &PingProtocol, + io: &mut T, + Ping(data): Ping, + ) -> io::Result<()> where - T: AsyncWrite + Unpin + Send + T: AsyncWrite + Unpin + Send, { write_length_prefixed(io, data).await?; io.close().await?; @@ -455,10 +494,14 @@ impl RequestResponseCodec for PingCodec { Ok(()) } - async fn write_response(&mut self, _: &PingProtocol, io: &mut T, Pong(data): Pong) - -> io::Result<()> + async fn write_response( + &mut self, + _: &PingProtocol, + io: &mut T, + Pong(data): Pong, + ) -> io::Result<()> where - T: AsyncWrite + Unpin + Send + T: AsyncWrite + Unpin + Send, { write_length_prefixed(io, data).await?; io.close().await?; diff --git a/src/bandwidth.rs b/src/bandwidth.rs index 87b66653cfc..a341b4dfbab 100644 --- a/src/bandwidth.rs +++ b/src/bandwidth.rs @@ -18,12 +18,26 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{Multiaddr, core::{Transport, transport::{ListenerEvent, TransportError}}}; +use crate::{ + core::{ + transport::{ListenerEvent, TransportError}, + Transport, + }, + Multiaddr, +}; use atomic::Atomic; -use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready}; +use futures::{ + io::{IoSlice, IoSliceMut}, + prelude::*, + ready, +}; use std::{ - convert::TryFrom as _, io, pin::Pin, sync::{atomic::Ordering, Arc}, task::{Context, Poll} + convert::TryFrom as _, + io, + pin::Pin, + sync::{atomic::Ordering, Arc}, + task::{Context, Poll}, }; /// Wraps around a `Transport` and counts the number of bytes that go through all the opened @@ -91,19 +105,18 @@ pub struct BandwidthListener { impl Stream for BandwidthListener where - TInner: TryStream, Error = TErr> + TInner: TryStream, Error = TErr>, { type Item = Result, TErr>, TErr>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - let event = - if let Some(event) = ready!(this.inner.try_poll_next(cx)?) { - event - } else { - return Poll::Ready(None) - }; + let event = if let Some(event) = ready!(this.inner.try_poll_next(cx)?) { + event + } else { + return Poll::Ready(None); + }; let event = event.map({ let sinks = this.sinks.clone(); @@ -129,7 +142,10 @@ impl Future for BandwidthFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let inner = ready!(this.inner.try_poll(cx)?); - let logged = BandwidthConnecLogging { inner, sinks: this.sinks.clone() }; + let logged = BandwidthConnecLogging { + inner, + sinks: this.sinks.clone(), + }; Poll::Ready(Ok(logged)) } } @@ -169,33 +185,61 @@ pub struct BandwidthConnecLogging { } impl AsyncRead for BandwidthConnecLogging { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { let this = self.project(); let num_bytes = ready!(this.inner.poll_read(cx, buf))?; - this.sinks.inbound.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + this.sinks.inbound.fetch_add( + u64::try_from(num_bytes).unwrap_or(u64::max_value()), + Ordering::Relaxed, + ); Poll::Ready(Ok(num_bytes)) } - fn poll_read_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>]) -> Poll> { + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { let this = self.project(); let num_bytes = ready!(this.inner.poll_read_vectored(cx, bufs))?; - this.sinks.inbound.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + this.sinks.inbound.fetch_add( + u64::try_from(num_bytes).unwrap_or(u64::max_value()), + Ordering::Relaxed, + ); Poll::Ready(Ok(num_bytes)) } } impl AsyncWrite for BandwidthConnecLogging { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { let this = self.project(); let num_bytes = ready!(this.inner.poll_write(cx, buf))?; - this.sinks.outbound.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + this.sinks.outbound.fetch_add( + u64::try_from(num_bytes).unwrap_or(u64::max_value()), + Ordering::Relaxed, + ); Poll::Ready(Ok(num_bytes)) } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll> { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { let this = self.project(); let num_bytes = ready!(this.inner.poll_write_vectored(cx, bufs))?; - this.sinks.outbound.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + this.sinks.outbound.fetch_add( + u64::try_from(num_bytes).unwrap_or(u64::max_value()), + Ordering::Relaxed, + ); Poll::Ready(Ok(num_bytes)) } diff --git a/src/lib.rs b/src/lib.rs index e675b40e7f0..fc48d34b83d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,9 +36,9 @@ pub use bytes; pub use futures; #[doc(inline)] -pub use multiaddr; -#[doc(inline)] pub use libp2p_core::multihash; +#[doc(inline)] +pub use multiaddr; #[doc(inline)] pub use libp2p_core as core; @@ -48,18 +48,13 @@ pub use libp2p_core as core; #[doc(inline)] pub use libp2p_deflate as deflate; #[cfg(any(feature = "dns-async-std", feature = "dns-tokio"))] -#[cfg_attr(docsrs, doc(cfg(any(feature = "dns-async-std", feature = "dns-tokio"))))] +#[cfg_attr( + docsrs, + doc(cfg(any(feature = "dns-async-std", feature = "dns-tokio"))) +)] #[cfg(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")))] #[doc(inline)] pub use libp2p_dns as dns; -#[cfg(feature = "identify")] -#[cfg_attr(docsrs, doc(cfg(feature = "identify")))] -#[doc(inline)] -pub use libp2p_identify as identify; -#[cfg(feature = "kad")] -#[cfg_attr(docsrs, doc(cfg(feature = "kad")))] -#[doc(inline)] -pub use libp2p_kad as kad; #[cfg(feature = "floodsub")] #[cfg_attr(docsrs, doc(cfg(feature = "floodsub")))] #[doc(inline)] @@ -68,15 +63,23 @@ pub use libp2p_floodsub as floodsub; #[cfg_attr(docsrs, doc(cfg(feature = "gossipsub")))] #[doc(inline)] pub use libp2p_gossipsub as gossipsub; -#[cfg(feature = "mplex")] -#[cfg_attr(docsrs, doc(cfg(feature = "mplex")))] +#[cfg(feature = "identify")] +#[cfg_attr(docsrs, doc(cfg(feature = "identify")))] #[doc(inline)] -pub use libp2p_mplex as mplex; +pub use libp2p_identify as identify; +#[cfg(feature = "kad")] +#[cfg_attr(docsrs, doc(cfg(feature = "kad")))] +#[doc(inline)] +pub use libp2p_kad as kad; #[cfg(feature = "mdns")] #[cfg_attr(docsrs, doc(cfg(feature = "mdns")))] #[cfg(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")))] #[doc(inline)] pub use libp2p_mdns as mdns; +#[cfg(feature = "mplex")] +#[cfg_attr(docsrs, doc(cfg(feature = "mplex")))] +#[doc(inline)] +pub use libp2p_mplex as mplex; #[cfg(feature = "noise")] #[cfg_attr(docsrs, doc(cfg(feature = "noise")))] #[doc(inline)] @@ -89,6 +92,18 @@ pub use libp2p_ping as ping; #[cfg_attr(docsrs, doc(cfg(feature = "plaintext")))] #[doc(inline)] pub use libp2p_plaintext as plaintext; +#[cfg(feature = "pnet")] +#[cfg_attr(docsrs, doc(cfg(feature = "pnet")))] +#[doc(inline)] +pub use libp2p_pnet as pnet; +#[cfg(feature = "relay")] +#[cfg_attr(docsrs, doc(cfg(feature = "relay")))] +#[doc(inline)] +pub use libp2p_relay as relay; +#[cfg(feature = "request-response")] +#[cfg_attr(docsrs, doc(cfg(feature = "request-response")))] +#[doc(inline)] +pub use libp2p_request_response as request_response; #[doc(inline)] pub use libp2p_swarm as swarm; #[cfg(any(feature = "tcp-async-io", feature = "tcp-tokio"))] @@ -113,18 +128,6 @@ pub use libp2p_websocket as websocket; #[cfg_attr(docsrs, doc(cfg(feature = "yamux")))] #[doc(inline)] pub use libp2p_yamux as yamux; -#[cfg(feature = "pnet")] -#[cfg_attr(docsrs, doc(cfg(feature = "pnet")))] -#[doc(inline)] -pub use libp2p_pnet as pnet; -#[cfg(feature = "relay")] -#[cfg_attr(docsrs, doc(cfg(feature = "relay")))] -#[doc(inline)] -pub use libp2p_relay as relay; -#[cfg(feature = "request-response")] -#[cfg_attr(docsrs, doc(cfg(feature = "request-response")))] -#[doc(inline)] -pub use libp2p_request_response as request_response; mod transport_ext; @@ -136,16 +139,15 @@ pub mod tutorial; pub use self::core::{ identity, - PeerId, - Transport, transport::TransportError, - upgrade::{InboundUpgrade, InboundUpgradeExt, OutboundUpgrade, OutboundUpgradeExt} + upgrade::{InboundUpgrade, InboundUpgradeExt, OutboundUpgrade, OutboundUpgradeExt}, + PeerId, Transport, }; -pub use libp2p_swarm_derive::NetworkBehaviour; -pub use self::multiaddr::{Multiaddr, multiaddr as build_multiaddr}; +pub use self::multiaddr::{multiaddr as build_multiaddr, Multiaddr}; pub use self::simple::SimpleProtocol; pub use self::swarm::Swarm; pub use self::transport_ext::TransportExt; +pub use libp2p_swarm_derive::NetworkBehaviour; /// Builds a `Transport` based on TCP/IP that supports the most commonly-used features of libp2p: /// @@ -158,11 +160,30 @@ pub use self::transport_ext::TransportExt; /// /// > **Note**: This `Transport` is not suitable for production usage, as its implementation /// > reserves the right to support additional protocols or remove deprecated protocols. -#[cfg(all(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), feature = "tcp-async-io", feature = "dns-async-std", feature = "websocket", feature = "noise", feature = "mplex", feature = "yamux"))] -#[cfg_attr(docsrs, doc(cfg(all(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), feature = "tcp-async-io", feature = "dns-async-std", feature = "websocket", feature = "noise", feature = "mplex", feature = "yamux"))))] -pub async fn development_transport(keypair: identity::Keypair) - -> std::io::Result> -{ +#[cfg(all( + not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), + feature = "tcp-async-io", + feature = "dns-async-std", + feature = "websocket", + feature = "noise", + feature = "mplex", + feature = "yamux" +))] +#[cfg_attr( + docsrs, + doc(cfg(all( + not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), + feature = "tcp-async-io", + feature = "dns-async-std", + feature = "websocket", + feature = "noise", + feature = "mplex", + feature = "yamux" + ))) +)] +pub async fn development_transport( + keypair: identity::Keypair, +) -> std::io::Result> { let transport = { let tcp = tcp::TcpConfig::new().nodelay(true); let dns_tcp = dns::DnsConfig::system(tcp).await?; @@ -177,7 +198,10 @@ pub async fn development_transport(keypair: identity::Keypair) Ok(transport .upgrade(core::upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) - .multiplex(core::upgrade::SelectUpgrade::new(yamux::YamuxConfig::default(), mplex::MplexConfig::default())) + .multiplex(core::upgrade::SelectUpgrade::new( + yamux::YamuxConfig::default(), + mplex::MplexConfig::default(), + )) .timeout(std::time::Duration::from_secs(20)) .boxed()) } @@ -193,11 +217,30 @@ pub async fn development_transport(keypair: identity::Keypair) /// /// > **Note**: This `Transport` is not suitable for production usage, as its implementation /// > reserves the right to support additional protocols or remove deprecated protocols. -#[cfg(all(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), feature = "tcp-tokio", feature = "dns-tokio", feature = "websocket", feature = "noise", feature = "mplex", feature = "yamux"))] -#[cfg_attr(docsrs, doc(cfg(all(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), feature = "tcp-tokio", feature = "dns-tokio", feature = "websocket", feature = "noise", feature = "mplex", feature = "yamux"))))] -pub fn tokio_development_transport(keypair: identity::Keypair) - -> std::io::Result> -{ +#[cfg(all( + not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), + feature = "tcp-tokio", + feature = "dns-tokio", + feature = "websocket", + feature = "noise", + feature = "mplex", + feature = "yamux" +))] +#[cfg_attr( + docsrs, + doc(cfg(all( + not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), + feature = "tcp-tokio", + feature = "dns-tokio", + feature = "websocket", + feature = "noise", + feature = "mplex", + feature = "yamux" + ))) +)] +pub fn tokio_development_transport( + keypair: identity::Keypair, +) -> std::io::Result> { let transport = { let tcp = tcp::TokioTcpConfig::new().nodelay(true); let dns_tcp = dns::TokioDnsConfig::system(tcp)?; @@ -212,7 +255,10 @@ pub fn tokio_development_transport(keypair: identity::Keypair) Ok(transport .upgrade(core::upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) - .multiplex(core::upgrade::SelectUpgrade::new(yamux::YamuxConfig::default(), mplex::MplexConfig::default())) + .multiplex(core::upgrade::SelectUpgrade::new( + yamux::YamuxConfig::default(), + mplex::MplexConfig::default(), + )) .timeout(std::time::Duration::from_secs(20)) .boxed()) } diff --git a/src/transport_ext.rs b/src/transport_ext.rs index de77007b9c4..fa8926c8380 100644 --- a/src/transport_ext.rs +++ b/src/transport_ext.rs @@ -33,7 +33,7 @@ pub trait TransportExt: Transport { /// of bytes transferred through the sockets. fn with_bandwidth_logging(self) -> (BandwidthLogging, Arc) where - Self: Sized + Self: Sized, { BandwidthLogging::new(self) } diff --git a/src/tutorial.rs b/src/tutorial.rs index 9d88bf54c9d..ddaa71350bc 100644 --- a/src/tutorial.rs +++ b/src/tutorial.rs @@ -349,8 +349,8 @@ //! //! Note: The [`Multiaddr`] at the end being one of the [`Multiaddr`] printed //! earlier in terminal window one. -//! Both peers have to be in the same network with which the address is associated. -//! In our case any printed addresses can be used, as both peers run on the same +//! Both peers have to be in the same network with which the address is associated. +//! In our case any printed addresses can be used, as both peers run on the same //! device. //! //! The two nodes will establish a connection and send each other ping and pong diff --git a/swarm-derive/src/lib.rs b/swarm-derive/src/lib.rs index a5cdf4900ca..3f92e549fa9 100644 --- a/swarm-derive/src/lib.rs +++ b/swarm-derive/src/lib.rs @@ -20,9 +20,9 @@ #![recursion_limit = "256"] -use quote::quote; use proc_macro::TokenStream; -use syn::{parse_macro_input, DeriveInput, Data, DataStruct, Ident}; +use quote::quote; +use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Ident}; /// Generates a delegating `NetworkBehaviour` implementation for the struct this is used for. See /// the trait documentation for better description. @@ -45,27 +45,27 @@ fn build(ast: &DeriveInput) -> TokenStream { fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { let name = &ast.ident; let (_, ty_generics, where_clause) = ast.generics.split_for_impl(); - let multiaddr = quote!{::libp2p::core::Multiaddr}; - let trait_to_impl = quote!{::libp2p::swarm::NetworkBehaviour}; - let net_behv_event_proc = quote!{::libp2p::swarm::NetworkBehaviourEventProcess}; - let either_ident = quote!{::libp2p::core::either::EitherOutput}; - let network_behaviour_action = quote!{::libp2p::swarm::NetworkBehaviourAction}; - let into_protocols_handler = quote!{::libp2p::swarm::IntoProtocolsHandler}; - let protocols_handler = quote!{::libp2p::swarm::ProtocolsHandler}; - let into_proto_select_ident = quote!{::libp2p::swarm::IntoProtocolsHandlerSelect}; - let peer_id = quote!{::libp2p::core::PeerId}; - let connection_id = quote!{::libp2p::core::connection::ConnectionId}; - let connected_point = quote!{::libp2p::core::ConnectedPoint}; - let listener_id = quote!{::libp2p::core::connection::ListenerId}; - - let poll_parameters = quote!{::libp2p::swarm::PollParameters}; + let multiaddr = quote! {::libp2p::core::Multiaddr}; + let trait_to_impl = quote! {::libp2p::swarm::NetworkBehaviour}; + let net_behv_event_proc = quote! {::libp2p::swarm::NetworkBehaviourEventProcess}; + let either_ident = quote! {::libp2p::core::either::EitherOutput}; + let network_behaviour_action = quote! {::libp2p::swarm::NetworkBehaviourAction}; + let into_protocols_handler = quote! {::libp2p::swarm::IntoProtocolsHandler}; + let protocols_handler = quote! {::libp2p::swarm::ProtocolsHandler}; + let into_proto_select_ident = quote! {::libp2p::swarm::IntoProtocolsHandlerSelect}; + let peer_id = quote! {::libp2p::core::PeerId}; + let connection_id = quote! {::libp2p::core::connection::ConnectionId}; + let connected_point = quote! {::libp2p::core::ConnectedPoint}; + let listener_id = quote! {::libp2p::core::connection::ListenerId}; + + let poll_parameters = quote! {::libp2p::swarm::PollParameters}; // Build the generics. let impl_generics = { let tp = ast.generics.type_params(); let lf = ast.generics.lifetimes(); let cst = ast.generics.const_params(); - quote!{<#(#lf,)* #(#tp,)* #(#cst,)*>} + quote! {<#(#lf,)* #(#tp,)* #(#cst,)*>} }; // Whether or not we require the `NetworkBehaviourEventProcess` trait to be implemented. @@ -75,12 +75,14 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { for meta_items in ast.attrs.iter().filter_map(get_meta_items) { for meta_item in meta_items { match meta_item { - syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("event_process") => { + syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) + if m.path.is_ident("event_process") => + { if let syn::Lit::Bool(ref b) = m.lit { event_process = b.value } } - _ => () + _ => (), } } } @@ -92,17 +94,19 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { // If we find a `#[behaviour(out_event = "Foo")]` attribute on the struct, we set `Foo` as // the out event. Otherwise we use `()`. let out_event = { - let mut out = quote!{()}; + let mut out = quote! {()}; for meta_items in ast.attrs.iter().filter_map(get_meta_items) { for meta_item in meta_items { match meta_item { - syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("out_event") => { + syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) + if m.path.is_ident("out_event") => + { if let syn::Lit::Str(ref s) = m.lit { let ident: syn::Type = syn::parse_str(&s.value()).unwrap(); - out = quote!{#ident}; + out = quote! {#ident}; } } - _ => () + _ => (), } } } @@ -111,70 +115,84 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { // Build the `where ...` clause of the trait implementation. let where_clause = { - let additional = data_struct.fields.iter() + let additional = data_struct + .fields + .iter() .filter(|x| !is_ignored(x)) .flat_map(|field| { let ty = &field.ty; vec![ - quote!{#ty: #trait_to_impl}, + quote! {#ty: #trait_to_impl}, if event_process { - quote!{Self: #net_behv_event_proc<<#ty as #trait_to_impl>::OutEvent>} + quote! {Self: #net_behv_event_proc<<#ty as #trait_to_impl>::OutEvent>} } else { - quote!{#out_event: From< <#ty as #trait_to_impl>::OutEvent >} - } + quote! {#out_event: From< <#ty as #trait_to_impl>::OutEvent >} + }, ] }) .collect::>(); if let Some(where_clause) = where_clause { if where_clause.predicates.trailing_punct() { - Some(quote!{#where_clause #(#additional),*}) + Some(quote! {#where_clause #(#additional),*}) } else { - Some(quote!{#where_clause, #(#additional),*}) + Some(quote! {#where_clause, #(#additional),*}) } } else { - Some(quote!{where #(#additional),*}) + Some(quote! {where #(#additional),*}) } }; // Build the list of statements to put in the body of `addresses_of_peer()`. let addresses_of_peer_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ out.extend(self.#i.addresses_of_peer(peer_id)); }, - None => quote!{ out.extend(self.#field_n.addresses_of_peer(peer_id)); }, + Some(match field.ident { + Some(ref i) => quote! { out.extend(self.#i.addresses_of_peer(peer_id)); }, + None => quote! { out.extend(self.#field_n.addresses_of_peer(peer_id)); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_connected()`. let inject_connected_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_connected(peer_id); }, - None => quote!{ self.#field_n.inject_connected(peer_id); }, + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_connected(peer_id); }, + None => quote! { self.#field_n.inject_connected(peer_id); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_disconnected()`. let inject_disconnected_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_disconnected(peer_id); }, - None => quote!{ self.#field_n.inject_disconnected(peer_id); }, + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_disconnected(peer_id); }, + None => quote! { self.#field_n.inject_disconnected(peer_id); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_connection_established()`. @@ -217,8 +235,9 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { }; // Build the list of statements to put in the body of `inject_addr_reach_failure()`. - let inject_addr_reach_failure_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { + let inject_addr_reach_failure_stmts = + { + data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { if is_ignored(&field) { return None; } @@ -228,116 +247,148 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { None => quote!{ self.#field_n.inject_addr_reach_failure(peer_id, addr, error); }, }) }) - }; + }; // Build the list of statements to put in the body of `inject_dial_failure()`. let inject_dial_failure_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_dial_failure(peer_id); }, - None => quote!{ self.#field_n.inject_dial_failure(peer_id); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_dial_failure(peer_id); }, + None => quote! { self.#field_n.inject_dial_failure(peer_id); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_new_listener()`. let inject_new_listener_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_new_listener(id); }, - None => quote!{ self.#field_n.inject_new_listener(id); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_new_listener(id); }, + None => quote! { self.#field_n.inject_new_listener(id); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_new_listen_addr()`. let inject_new_listen_addr_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_new_listen_addr(id, addr); }, - None => quote!{ self.#field_n.inject_new_listen_addr(id, addr); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_new_listen_addr(id, addr); }, + None => quote! { self.#field_n.inject_new_listen_addr(id, addr); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_expired_listen_addr()`. let inject_expired_listen_addr_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_expired_listen_addr(id, addr); }, - None => quote!{ self.#field_n.inject_expired_listen_addr(id, addr); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_expired_listen_addr(id, addr); }, + None => quote! { self.#field_n.inject_expired_listen_addr(id, addr); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_new_external_addr()`. let inject_new_external_addr_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_new_external_addr(addr); }, - None => quote!{ self.#field_n.inject_new_external_addr(addr); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_new_external_addr(addr); }, + None => quote! { self.#field_n.inject_new_external_addr(addr); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_expired_external_addr()`. let inject_expired_external_addr_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_expired_external_addr(addr); }, - None => quote!{ self.#field_n.inject_expired_external_addr(addr); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_expired_external_addr(addr); }, + None => quote! { self.#field_n.inject_expired_external_addr(addr); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_listener_error()`. let inject_listener_error_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None - } - Some(match field.ident { - Some(ref i) => quote!(self.#i.inject_listener_error(id, err);), - None => quote!(self.#field_n.inject_listener_error(id, err);) + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } + Some(match field.ident { + Some(ref i) => quote!(self.#i.inject_listener_error(id, err);), + None => quote!(self.#field_n.inject_listener_error(id, err);), + }) }) - }) }; // Build the list of statements to put in the body of `inject_listener_closed()`. let inject_listener_closed_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None - } - Some(match field.ident { - Some(ref i) => quote!(self.#i.inject_listener_closed(id, reason);), - None => quote!(self.#field_n.inject_listener_closed(id, reason);) + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } + Some(match field.ident { + Some(ref i) => quote!(self.#i.inject_listener_closed(id, reason);), + None => quote!(self.#field_n.inject_listener_closed(id, reason);), + }) }) - }) }; // Build the list of variants to put in the body of `inject_event()`. @@ -369,13 +420,13 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { continue; } let ty = &field.ty; - let field_info = quote!{ <#ty as #trait_to_impl>::ProtocolsHandler }; + let field_info = quote! { <#ty as #trait_to_impl>::ProtocolsHandler }; match ph_ty { - Some(ev) => ph_ty = Some(quote!{ #into_proto_select_ident<#ev, #field_info> }), + Some(ev) => ph_ty = Some(quote! { #into_proto_select_ident<#ev, #field_info> }), ref mut ev @ None => *ev = Some(field_info), } } - ph_ty.unwrap_or(quote!{()}) // TODO: `!` instead + ph_ty.unwrap_or(quote! {()}) // TODO: `!` instead }; // The content of `new_handler()`. @@ -389,8 +440,8 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { } let field_name = match field.ident { - Some(ref i) => quote!{ self.#i }, - None => quote!{ self.#field_n }, + Some(ref i) => quote! { self.#i }, + None => quote! { self.#field_n }, }; let builder = quote! { @@ -398,29 +449,33 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { }; match out_handler { - Some(h) => out_handler = Some(quote!{ #into_protocols_handler::select(#h, #builder) }), + Some(h) => { + out_handler = Some(quote! { #into_protocols_handler::select(#h, #builder) }) + } ref mut h @ None => *h = Some(builder), } } - out_handler.unwrap_or(quote!{()}) // TODO: incorrect + out_handler.unwrap_or(quote! {()}) // TODO: incorrect }; // The method to use to poll. // If we find a `#[behaviour(poll_method = "poll")]` attribute on the struct, we call // `self.poll()` at the end of the polling. let poll_method = { - let mut poll_method = quote!{std::task::Poll::Pending}; + let mut poll_method = quote! {std::task::Poll::Pending}; for meta_items in ast.attrs.iter().filter_map(get_meta_items) { for meta_item in meta_items { match meta_item { - syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("poll_method") => { + syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) + if m.path.is_ident("poll_method") => + { if let syn::Lit::Str(ref s) = m.lit { let ident: Ident = syn::parse_str(&s.value()).unwrap(); - poll_method = quote!{#name::#ident(self, cx, poll_params)}; + poll_method = quote! {#name::#ident(self, cx, poll_params)}; } } - _ => () + _ => (), } } } @@ -489,7 +544,7 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { }); // Now the magic happens. - let final_quote = quote!{ + let final_quote = quote! { impl #impl_generics #trait_to_impl for #name #ty_generics #where_clause { @@ -609,7 +664,7 @@ fn is_ignored(field: &syn::Field) -> bool { syn::NestedMeta::Meta(syn::Meta::Path(ref m)) if m.is_ident("ignore") => { return true; } - _ => () + _ => (), } } } diff --git a/swarm-derive/tests/test.rs b/swarm-derive/tests/test.rs index e1913b7eab9..78a9ed985f9 100644 --- a/swarm-derive/tests/test.rs +++ b/swarm-derive/tests/test.rs @@ -43,8 +43,7 @@ fn one_field() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } #[allow(dead_code)] @@ -63,13 +62,11 @@ fn two_fields() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } #[allow(dead_code)] @@ -91,18 +88,15 @@ fn three_fields() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::kad::KademliaEvent) { - } + fn inject_event(&mut self, _: libp2p::kad::KademliaEvent) {} } #[allow(dead_code)] @@ -123,13 +117,11 @@ fn three_fields_non_last_ignored() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::kad::KademliaEvent) { - } + fn inject_event(&mut self, _: libp2p::kad::KademliaEvent) {} } #[allow(dead_code)] @@ -149,17 +141,21 @@ fn custom_polling() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } impl Foo { - fn foo(&mut self, _: &mut std::task::Context, _: &mut impl libp2p::swarm::PollParameters) -> std::task::Poll> { std::task::Poll::Pending } + fn foo( + &mut self, + _: &mut std::task::Context, + _: &mut impl libp2p::swarm::PollParameters, + ) -> std::task::Poll> { + std::task::Poll::Pending + } } #[allow(dead_code)] @@ -179,13 +175,11 @@ fn custom_event_no_polling() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } #[allow(dead_code)] @@ -205,17 +199,21 @@ fn custom_event_and_polling() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } impl Foo { - fn foo(&mut self, _: &mut std::task::Context, _: &mut impl libp2p::swarm::PollParameters) -> std::task::Poll> { std::task::Poll::Pending } + fn foo( + &mut self, + _: &mut std::task::Context, + _: &mut impl libp2p::swarm::PollParameters, + ) -> std::task::Poll> { + std::task::Poll::Pending + } } #[allow(dead_code)] @@ -251,13 +249,11 @@ fn nested_derives_with_import() { } impl NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl NetworkBehaviourEventProcess<()> for Bar { - fn inject_event(&mut self, _: ()) { - } + fn inject_event(&mut self, _: ()) {} } #[allow(dead_code)] @@ -270,7 +266,7 @@ fn nested_derives_with_import() { fn event_process_false() { enum BehaviourOutEvent { Ping(libp2p::ping::PingEvent), - Identify(libp2p::identify::IdentifyEvent) + Identify(libp2p::identify::IdentifyEvent), } impl From for BehaviourOutEvent { @@ -302,7 +298,7 @@ fn event_process_false() { // check that the event is bubbled up all the way to swarm let _ = async { loop { - match _swarm.select_next_some().await { + match _swarm.select_next_some().await { SwarmEvent::Behaviour(BehaviourOutEvent::Ping(_)) => break, SwarmEvent::Behaviour(BehaviourOutEvent::Identify(_)) => break, _ => {} diff --git a/swarm/src/behaviour.rs b/swarm/src/behaviour.rs index 41cf11ffc8e..a21c7a023b8 100644 --- a/swarm/src/behaviour.rs +++ b/swarm/src/behaviour.rs @@ -18,9 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{AddressScore, AddressRecord}; use crate::protocols_handler::{IntoProtocolsHandler, ProtocolsHandler}; -use libp2p_core::{ConnectedPoint, Multiaddr, PeerId, connection::{ConnectionId, ListenerId}}; +use crate::{AddressRecord, AddressScore}; +use libp2p_core::{ + connection::{ConnectionId, ListenerId}, + ConnectedPoint, Multiaddr, PeerId, +}; use std::{error, task::Context, task::Poll}; /// A behaviour for the network. Allows customizing the swarm. @@ -90,7 +93,7 @@ pub trait NetworkBehaviour: Send + 'static { /// /// This method is only called when the first connection to the peer is established, preceded by /// [`inject_connection_established`](NetworkBehaviour::inject_connection_established). - fn inject_connected(&mut self, _: &PeerId) { } + fn inject_connected(&mut self, _: &PeerId) {} /// Indicates to the behaviour that we disconnected from the node with the given peer id. /// @@ -99,19 +102,17 @@ pub trait NetworkBehaviour: Send + 'static { /// /// This method is only called when the last established connection to the peer is closed, /// preceded by [`inject_connection_closed`](NetworkBehaviour::inject_connection_closed). - fn inject_disconnected(&mut self, _: &PeerId) { } + fn inject_disconnected(&mut self, _: &PeerId) {} /// Informs the behaviour about a newly established connection to a peer. - fn inject_connection_established(&mut self, _: &PeerId, _: &ConnectionId, _: &ConnectedPoint) - {} + fn inject_connection_established(&mut self, _: &PeerId, _: &ConnectionId, _: &ConnectedPoint) {} /// Informs the behaviour about a closed connection to a peer. /// /// A call to this method is always paired with an earlier call to /// `inject_connection_established` with the same peer ID, connection ID and /// endpoint. - fn inject_connection_closed(&mut self, _: &PeerId, _: &ConnectionId, _: &ConnectedPoint) - {} + fn inject_connection_closed(&mut self, _: &PeerId, _: &ConnectionId, _: &ConnectedPoint) {} /// Informs the behaviour that the [`ConnectedPoint`] of an existing connection has changed. fn inject_address_change( @@ -119,8 +120,9 @@ pub trait NetworkBehaviour: Send + 'static { _: &PeerId, _: &ConnectionId, _old: &ConnectedPoint, - _new: &ConnectedPoint - ) {} + _new: &ConnectedPoint, + ) { + } /// Informs the behaviour about an event generated by the handler dedicated to the peer identified by `peer_id`. /// for the behaviour. @@ -131,14 +133,19 @@ pub trait NetworkBehaviour: Send + 'static { &mut self, peer_id: PeerId, connection: ConnectionId, - event: <::Handler as ProtocolsHandler>::OutEvent + event: <::Handler as ProtocolsHandler>::OutEvent, ); /// Indicates to the behaviour that we tried to reach an address, but failed. /// /// If we were trying to reach a specific node, its ID is passed as parameter. If this is the /// last address to attempt for the given node, then `inject_dial_failure` is called afterwards. - fn inject_addr_reach_failure(&mut self, _peer_id: Option<&PeerId>, _addr: &Multiaddr, _error: &dyn error::Error) { + fn inject_addr_reach_failure( + &mut self, + _peer_id: Option<&PeerId>, + _addr: &Multiaddr, + _error: &dyn error::Error, + ) { } /// Indicates to the behaviour that we tried to dial all the addresses known for a node, but @@ -146,37 +153,30 @@ pub trait NetworkBehaviour: Send + 'static { /// /// The `peer_id` is guaranteed to be in a disconnected state. In other words, /// `inject_connected` has not been called, or `inject_disconnected` has been called since then. - fn inject_dial_failure(&mut self, _peer_id: &PeerId) { - } + fn inject_dial_failure(&mut self, _peer_id: &PeerId) {} /// Indicates to the behaviour that a new listener was created. - fn inject_new_listener(&mut self, _id: ListenerId) { - } + fn inject_new_listener(&mut self, _id: ListenerId) {} /// Indicates to the behaviour that we have started listening on a new multiaddr. - fn inject_new_listen_addr(&mut self, _id: ListenerId, _addr: &Multiaddr) { - } + fn inject_new_listen_addr(&mut self, _id: ListenerId, _addr: &Multiaddr) {} /// Indicates to the behaviour that a multiaddr we were listening on has expired, /// which means that we are no longer listening in it. - fn inject_expired_listen_addr(&mut self, _id: ListenerId, _addr: &Multiaddr) { - } + fn inject_expired_listen_addr(&mut self, _id: ListenerId, _addr: &Multiaddr) {} /// A listener experienced an error. fn inject_listener_error(&mut self, _id: ListenerId, _err: &(dyn std::error::Error + 'static)) { } /// A listener closed. - fn inject_listener_closed(&mut self, _id: ListenerId, _reason: Result<(), &std::io::Error>) { - } + fn inject_listener_closed(&mut self, _id: ListenerId, _reason: Result<(), &std::io::Error>) {} /// Indicates to the behaviour that we have discovered a new external address for us. - fn inject_new_external_addr(&mut self, _addr: &Multiaddr) { - } + fn inject_new_external_addr(&mut self, _addr: &Multiaddr) {} /// Indicates to the behaviour that an external address was removed. - fn inject_expired_external_addr(&mut self, _addr: &Multiaddr) { - } + fn inject_expired_external_addr(&mut self, _addr: &Multiaddr) {} /// Polls for things that swarm should do. /// @@ -311,47 +311,71 @@ pub enum NetworkBehaviourAction { peer_id: PeerId, /// Whether to close a specific or all connections to the given peer. connection: CloseConnection, - } + }, } impl NetworkBehaviourAction { /// Map the handler event. pub fn map_in(self, f: impl FnOnce(TInEvent) -> E) -> NetworkBehaviourAction { match self { - NetworkBehaviourAction::GenerateEvent(e) => - NetworkBehaviourAction::GenerateEvent(e), - NetworkBehaviourAction::DialAddress { address } => - NetworkBehaviourAction::DialAddress { address }, - NetworkBehaviourAction::DialPeer { peer_id, condition } => - NetworkBehaviourAction::DialPeer { peer_id, condition }, - NetworkBehaviourAction::NotifyHandler { peer_id, handler, event } => - NetworkBehaviourAction::NotifyHandler { - peer_id, - handler, - event: f(event) - }, - NetworkBehaviourAction::ReportObservedAddr { address, score } => - NetworkBehaviourAction::ReportObservedAddr { address, score }, - NetworkBehaviourAction::CloseConnection { peer_id, connection } => - NetworkBehaviourAction::CloseConnection { peer_id, connection } + NetworkBehaviourAction::GenerateEvent(e) => NetworkBehaviourAction::GenerateEvent(e), + NetworkBehaviourAction::DialAddress { address } => { + NetworkBehaviourAction::DialAddress { address } + } + NetworkBehaviourAction::DialPeer { peer_id, condition } => { + NetworkBehaviourAction::DialPeer { peer_id, condition } + } + NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + } => NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event: f(event), + }, + NetworkBehaviourAction::ReportObservedAddr { address, score } => { + NetworkBehaviourAction::ReportObservedAddr { address, score } + } + NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + } => NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }, } } /// Map the event the swarm will return. pub fn map_out(self, f: impl FnOnce(TOutEvent) -> E) -> NetworkBehaviourAction { match self { - NetworkBehaviourAction::GenerateEvent(e) => - NetworkBehaviourAction::GenerateEvent(f(e)), - NetworkBehaviourAction::DialAddress { address } => - NetworkBehaviourAction::DialAddress { address }, - NetworkBehaviourAction::DialPeer { peer_id, condition } => - NetworkBehaviourAction::DialPeer { peer_id, condition }, - NetworkBehaviourAction::NotifyHandler { peer_id, handler, event } => - NetworkBehaviourAction::NotifyHandler { peer_id, handler, event }, - NetworkBehaviourAction::ReportObservedAddr { address, score } => - NetworkBehaviourAction::ReportObservedAddr { address, score }, - NetworkBehaviourAction::CloseConnection { peer_id, connection } => - NetworkBehaviourAction::CloseConnection { peer_id, connection } + NetworkBehaviourAction::GenerateEvent(e) => NetworkBehaviourAction::GenerateEvent(f(e)), + NetworkBehaviourAction::DialAddress { address } => { + NetworkBehaviourAction::DialAddress { address } + } + NetworkBehaviourAction::DialPeer { peer_id, condition } => { + NetworkBehaviourAction::DialPeer { peer_id, condition } + } + NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + } => NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + }, + NetworkBehaviourAction::ReportObservedAddr { address, score } => { + NetworkBehaviourAction::ReportObservedAddr { address, score } + } + NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + } => NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }, } } } diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index ae23b94be3e..0bc718690be 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -63,62 +63,42 @@ pub mod protocols_handler; pub mod toggle; pub use behaviour::{ - NetworkBehaviour, - NetworkBehaviourAction, - NetworkBehaviourEventProcess, - PollParameters, - NotifyHandler, - DialPeerCondition, - CloseConnection + CloseConnection, DialPeerCondition, NetworkBehaviour, NetworkBehaviourAction, + NetworkBehaviourEventProcess, NotifyHandler, PollParameters, }; pub use protocols_handler::{ - IntoProtocolsHandler, - IntoProtocolsHandlerSelect, - KeepAlive, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerSelect, - ProtocolsHandlerUpgrErr, - OneShotHandler, - OneShotHandlerConfig, - SubstreamProtocol + IntoProtocolsHandler, IntoProtocolsHandlerSelect, KeepAlive, OneShotHandler, + OneShotHandlerConfig, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerSelect, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; -pub use registry::{AddressScore, AddressRecord, AddAddressResult}; - -use protocols_handler::{ - NodeHandlerWrapperBuilder, - NodeHandlerWrapperError, -}; -use futures::{ - prelude::*, - executor::ThreadPoolBuilder, - stream::FusedStream, -}; -use libp2p_core::{Executor, Multiaddr, Negotiated, PeerId, Transport, connection::{ - ConnectionError, - ConnectionId, - ConnectionLimit, - ConnectedPoint, - EstablishedConnection, - ConnectionHandler, - IntoConnectionHandler, - ListenerId, - PendingConnectionError, - Substream - }, muxing::StreamMuxerBox, network::{ - self, - ConnectionLimits, - Network, +pub use registry::{AddAddressResult, AddressRecord, AddressScore}; + +use futures::{executor::ThreadPoolBuilder, prelude::*, stream::FusedStream}; +use libp2p_core::{ + connection::{ + ConnectedPoint, ConnectionError, ConnectionHandler, ConnectionId, ConnectionLimit, + EstablishedConnection, IntoConnectionHandler, ListenerId, PendingConnectionError, + Substream, + }, + muxing::StreamMuxerBox, + network::{ + self, peer::ConnectedPeer, ConnectionLimits, Network, NetworkConfig, NetworkEvent, NetworkInfo, - NetworkEvent, - NetworkConfig, - peer::ConnectedPeer, - }, transport::{self, TransportError}, upgrade::{ProtocolName}}; -use registry::{Addresses, AddressIntoIter}; + }, + transport::{self, TransportError}, + upgrade::ProtocolName, + Executor, Multiaddr, Negotiated, PeerId, Transport, +}; +use protocols_handler::{NodeHandlerWrapperBuilder, NodeHandlerWrapperError}; +use registry::{AddressIntoIter, Addresses}; use smallvec::SmallVec; -use std::{error, fmt, io, pin::Pin, task::{Context, Poll}}; use std::collections::HashSet; use std::num::{NonZeroU32, NonZeroUsize}; +use std::{ + error, fmt, io, + pin::Pin, + task::{Context, Poll}, +}; use upgrade::UpgradeInfoSend as _; /// Substream for which a protocol has been chosen. @@ -136,13 +116,16 @@ type THandler = ::ProtocolsHandler; /// Custom event that can be received by the [`ProtocolsHandler`] of the /// [`NetworkBehaviour`]. -type THandlerInEvent = < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::InEvent; +type THandlerInEvent = + < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::InEvent; /// Custom event that can be produced by the [`ProtocolsHandler`] of the [`NetworkBehaviour`]. -type THandlerOutEvent = < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::OutEvent; +type THandlerOutEvent = + < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::OutEvent; /// Custom error that can be produced by the [`ProtocolsHandler`] of the [`NetworkBehaviour`]. -type THandlerErr = < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::Error; +type THandlerErr = + < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::Error; /// Event generated by the `Swarm`. #[derive(Debug)] @@ -228,18 +211,18 @@ pub enum SwarmEvent { error: PendingConnectionError, }, /// One of our listeners has reported a new local listening address. - NewListenAddr{ + NewListenAddr { /// The listener that is listening on the new address. listener_id: ListenerId, /// The new address that is being listened on. - address: Multiaddr + address: Multiaddr, }, /// One of our listeners has reported the expiration of a listening address. - ExpiredListenAddr{ + ExpiredListenAddr { /// The listener that is no longer listening on the address. listener_id: ListenerId, /// The expired address. - address: Multiaddr + address: Multiaddr, }, /// One of the listeners gracefully closed. ListenerClosed { @@ -308,11 +291,7 @@ where substream_upgrade_protocol_override: Option, } -impl Unpin for Swarm -where - TBehaviour: NetworkBehaviour, -{ -} +impl Unpin for Swarm where TBehaviour: NetworkBehaviour {} impl Swarm where @@ -322,7 +301,7 @@ where pub fn new( transport: transport::Boxed<(PeerId, StreamMuxerBox)>, behaviour: TBehaviour, - local_peer_id: PeerId + local_peer_id: PeerId, ) -> Self { SwarmBuilder::new(transport, behaviour, local_peer_id).build() } @@ -352,7 +331,9 @@ where /// Initiates a new dialing attempt to the given address. pub fn dial_addr(&mut self, addr: Multiaddr) -> Result<(), DialError> { - let handler = self.behaviour.new_handler() + let handler = self + .behaviour + .new_handler() .into_node_handler_builder() .with_substream_upgrade_protocol_override(self.substream_upgrade_protocol_override); Ok(self.network.dial(&addr, handler).map(|_id| ())?) @@ -362,31 +343,37 @@ where pub fn dial(&mut self, peer_id: &PeerId) -> Result<(), DialError> { if self.banned_peers.contains(peer_id) { self.behaviour.inject_dial_failure(peer_id); - return Err(DialError::Banned) + return Err(DialError::Banned); } let self_listening = &self.listened_addrs; - let mut addrs = self.behaviour.addresses_of_peer(peer_id) + let mut addrs = self + .behaviour + .addresses_of_peer(peer_id) .into_iter() .filter(|a| !self_listening.contains(a)); - let result = - if let Some(first) = addrs.next() { - let handler = self.behaviour.new_handler() - .into_node_handler_builder() - .with_substream_upgrade_protocol_override(self.substream_upgrade_protocol_override); - self.network.peer(*peer_id) - .dial(first, addrs, handler) - .map(|_| ()) - .map_err(DialError::from) - } else { - Err(DialError::NoAddresses) - }; + let result = if let Some(first) = addrs.next() { + let handler = self + .behaviour + .new_handler() + .into_node_handler_builder() + .with_substream_upgrade_protocol_override(self.substream_upgrade_protocol_override); + self.network + .peer(*peer_id) + .dial(first, addrs, handler) + .map(|_| ()) + .map_err(DialError::from) + } else { + Err(DialError::NoAddresses) + }; if let Err(error) = &result { log::debug!( "New dialing attempt to peer {:?} failed: {:?}.", - peer_id, error); + peer_id, + error + ); self.behaviour.inject_dial_failure(&peer_id); } @@ -508,9 +495,10 @@ where /// Internal function used by everything event-related. /// /// Polls the `Swarm` for the next event. - fn poll_next_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll>> - { + fn poll_next_event( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { // We use a `this` variable because the compiler can't mutably borrow multiple times // across a `Deref`. let this = &mut *self; @@ -525,38 +513,62 @@ where let peer = connection.peer_id(); let connection = connection.id(); this.behaviour.inject_event(peer, connection, event); - }, - Poll::Ready(NetworkEvent::AddressChange { connection, new_endpoint, old_endpoint }) => { + } + Poll::Ready(NetworkEvent::AddressChange { + connection, + new_endpoint, + old_endpoint, + }) => { let peer = connection.peer_id(); let connection = connection.id(); - this.behaviour.inject_address_change(&peer, &connection, &old_endpoint, &new_endpoint); - }, - Poll::Ready(NetworkEvent::ConnectionEstablished { connection, num_established }) => { + this.behaviour.inject_address_change( + &peer, + &connection, + &old_endpoint, + &new_endpoint, + ); + } + Poll::Ready(NetworkEvent::ConnectionEstablished { + connection, + num_established, + }) => { let peer_id = connection.peer_id(); let endpoint = connection.endpoint().clone(); if this.banned_peers.contains(&peer_id) { - this.network.peer(peer_id) + this.network + .peer(peer_id) .into_connected() .expect("the Network just notified us that we were connected; QED") .disconnect(); - return Poll::Ready(SwarmEvent::BannedPeer { - peer_id, - endpoint, - }); + return Poll::Ready(SwarmEvent::BannedPeer { peer_id, endpoint }); } else { - log::debug!("Connection established: {:?}; Total (peer): {}.", - connection.connected(), num_established); + log::debug!( + "Connection established: {:?}; Total (peer): {}.", + connection.connected(), + num_established + ); let endpoint = connection.endpoint().clone(); - this.behaviour.inject_connection_established(&peer_id, &connection.id(), &endpoint); + this.behaviour.inject_connection_established( + &peer_id, + &connection.id(), + &endpoint, + ); if num_established.get() == 1 { this.behaviour.inject_connected(&peer_id); } return Poll::Ready(SwarmEvent::ConnectionEstablished { - peer_id, num_established, endpoint + peer_id, + num_established, + endpoint, }); } - }, - Poll::Ready(NetworkEvent::ConnectionClosed { id, connected, error, num_established }) => { + } + Poll::Ready(NetworkEvent::ConnectionClosed { + id, + connected, + error, + num_established, + }) => { if let Some(error) = error.as_ref() { log::debug!("Connection {:?} closed: {:?}", connected, error); } else { @@ -564,7 +576,8 @@ where } let peer_id = connected.peer_id; let endpoint = connected.endpoint; - this.behaviour.inject_connection_closed(&peer_id, &id, &endpoint); + this.behaviour + .inject_connection_closed(&peer_id, &id, &endpoint); if num_established == 0 { this.behaviour.inject_disconnected(&peer_id); } @@ -574,11 +587,15 @@ where cause: error, num_established, }); - }, + } Poll::Ready(NetworkEvent::IncomingConnection { connection, .. }) => { - let handler = this.behaviour.new_handler() + let handler = this + .behaviour + .new_handler() .into_node_handler_builder() - .with_substream_upgrade_protocol_override(this.substream_upgrade_protocol_override); + .with_substream_upgrade_protocol_override( + this.substream_upgrade_protocol_override, + ); let local_addr = connection.local_addr.clone(); let send_back_addr = connection.send_back_addr.clone(); if let Err(e) = this.network.accept(connection, handler) { @@ -588,36 +605,55 @@ where local_addr, send_back_addr, }); - }, - Poll::Ready(NetworkEvent::NewListenerAddress { listener_id, listen_addr }) => { + } + Poll::Ready(NetworkEvent::NewListenerAddress { + listener_id, + listen_addr, + }) => { log::debug!("Listener {:?}; New address: {:?}", listener_id, listen_addr); if !this.listened_addrs.contains(&listen_addr) { this.listened_addrs.push(listen_addr.clone()) } - this.behaviour.inject_new_listen_addr(listener_id, &listen_addr); + this.behaviour + .inject_new_listen_addr(listener_id, &listen_addr); return Poll::Ready(SwarmEvent::NewListenAddr { listener_id, - address: listen_addr + address: listen_addr, }); } - Poll::Ready(NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr }) => { - log::debug!("Listener {:?}; Expired address {:?}.", listener_id, listen_addr); + Poll::Ready(NetworkEvent::ExpiredListenerAddress { + listener_id, + listen_addr, + }) => { + log::debug!( + "Listener {:?}; Expired address {:?}.", + listener_id, + listen_addr + ); this.listened_addrs.retain(|a| a != &listen_addr); - this.behaviour.inject_expired_listen_addr(listener_id, &listen_addr); - return Poll::Ready(SwarmEvent::ExpiredListenAddr{ + this.behaviour + .inject_expired_listen_addr(listener_id, &listen_addr); + return Poll::Ready(SwarmEvent::ExpiredListenAddr { listener_id, - address: listen_addr + address: listen_addr, }); } - Poll::Ready(NetworkEvent::ListenerClosed { listener_id, addresses, reason }) => { + Poll::Ready(NetworkEvent::ListenerClosed { + listener_id, + addresses, + reason, + }) => { log::debug!("Listener {:?}; Closed by {:?}.", listener_id, reason); for addr in addresses.iter() { this.behaviour.inject_expired_listen_addr(listener_id, addr); } - this.behaviour.inject_listener_closed(listener_id, match &reason { - Ok(()) => Ok(()), - Err(err) => Err(err), - }); + this.behaviour.inject_listener_closed( + listener_id, + match &reason { + Ok(()) => Ok(()), + Err(err) => Err(err), + }, + ); return Poll::Ready(SwarmEvent::ListenerClosed { listener_id, addresses, @@ -626,24 +662,31 @@ where } Poll::Ready(NetworkEvent::ListenerError { listener_id, error }) => { this.behaviour.inject_listener_error(listener_id, &error); - return Poll::Ready(SwarmEvent::ListenerError { - listener_id, - error, - }); - }, - Poll::Ready(NetworkEvent::IncomingConnectionError { local_addr, send_back_addr, error }) => { + return Poll::Ready(SwarmEvent::ListenerError { listener_id, error }); + } + Poll::Ready(NetworkEvent::IncomingConnectionError { + local_addr, + send_back_addr, + error, + }) => { log::debug!("Incoming connection failed: {:?}", error); return Poll::Ready(SwarmEvent::IncomingConnectionError { local_addr, send_back_addr, error, }); - }, - Poll::Ready(NetworkEvent::DialError { peer_id, multiaddr, error, attempts_remaining }) => { + } + Poll::Ready(NetworkEvent::DialError { + peer_id, + multiaddr, + error, + attempts_remaining, + }) => { log::debug!( "Connection attempt to {:?} via {:?} failed with {:?}. Attempts remaining: {}.", peer_id, multiaddr, error, attempts_remaining); - this.behaviour.inject_addr_reach_failure(Some(&peer_id), &multiaddr, &error); + this.behaviour + .inject_addr_reach_failure(Some(&peer_id), &multiaddr, &error); if attempts_remaining == 0 { this.behaviour.inject_dial_failure(&peer_id); } @@ -653,16 +696,22 @@ where error, attempts_remaining, }); - }, - Poll::Ready(NetworkEvent::UnknownPeerDialError { multiaddr, error, .. }) => { - log::debug!("Connection attempt to address {:?} of unknown peer failed with {:?}", - multiaddr, error); - this.behaviour.inject_addr_reach_failure(None, &multiaddr, &error); + } + Poll::Ready(NetworkEvent::UnknownPeerDialError { + multiaddr, error, .. + }) => { + log::debug!( + "Connection attempt to address {:?} of unknown peer failed with {:?}", + multiaddr, + error + ); + this.behaviour + .inject_addr_reach_failure(None, &multiaddr, &error); return Poll::Ready(SwarmEvent::UnknownPeerUnreachableAddr { address: multiaddr, error, }); - }, + } } // After the network had a chance to make progress, try to deliver @@ -673,18 +722,21 @@ where if let Some((peer_id, handler, event)) = this.pending_event.take() { if let Some(mut peer) = this.network.peer(peer_id).into_connected() { match handler { - PendingNotifyHandler::One(conn_id) => + PendingNotifyHandler::One(conn_id) => { if let Some(mut conn) = peer.connection(conn_id) { if let Some(event) = notify_one(&mut conn, event, cx) { this.pending_event = Some((peer_id, handler, event)); - return Poll::Pending + return Poll::Pending; } - }, + } + } PendingNotifyHandler::Any(ids) => { - if let Some((event, ids)) = notify_any::<_, _, TBehaviour>(ids, &mut peer, event, cx) { + if let Some((event, ids)) = + notify_any::<_, _, TBehaviour>(ids, &mut peer, event, cx) + { let handler = PendingNotifyHandler::Any(ids); this.pending_event = Some((peer_id, handler, event)); - return Poll::Pending + return Poll::Pending; } } } @@ -698,7 +750,7 @@ where local_peer_id: &mut this.network.local_peer_id(), supported_protocols: &this.supported_protocols, listened_addrs: &this.listened_addrs, - external_addrs: &this.external_addrs + external_addrs: &this.external_addrs, }; this.behaviour.poll(cx, &mut parameters) }; @@ -708,29 +760,34 @@ where Poll::Pending => (), Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) => { return Poll::Ready(SwarmEvent::Behaviour(event)) - }, + } Poll::Ready(NetworkBehaviourAction::DialAddress { address }) => { let _ = Swarm::dial_addr(&mut *this, address); - }, + } Poll::Ready(NetworkBehaviourAction::DialPeer { peer_id, condition }) => { if this.banned_peers.contains(&peer_id) { this.behaviour.inject_dial_failure(&peer_id); } else { let condition_matched = match condition { - DialPeerCondition::Disconnected => this.network.is_disconnected(&peer_id), + DialPeerCondition::Disconnected => { + this.network.is_disconnected(&peer_id) + } DialPeerCondition::NotDialing => !this.network.is_dialing(&peer_id), DialPeerCondition::Always => true, }; if condition_matched { if Swarm::dial(this, &peer_id).is_ok() { - return Poll::Ready(SwarmEvent::Dialing(peer_id)) + return Poll::Ready(SwarmEvent::Dialing(peer_id)); } } else { // Even if the condition for a _new_ dialing attempt is not met, // we always add any potentially new addresses of the peer to an // ongoing dialing attempt, if there is one. - log::trace!("Condition for new dialing attempt to {:?} not met: {:?}", - peer_id, condition); + log::trace!( + "Condition for new dialing attempt to {:?} not met: {:?}", + peer_id, + condition + ); let self_listening = &this.listened_addrs; if let Some(mut peer) = this.network.peer(peer_id).into_dialing() { let addrs = this.behaviour.addresses_of_peer(peer.id()); @@ -743,8 +800,12 @@ where } } } - }, - Poll::Ready(NetworkBehaviourAction::NotifyHandler { peer_id, handler, event }) => { + } + Poll::Ready(NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + }) => { if let Some(mut peer) = this.network.peer(peer_id).into_connected() { match handler { NotifyHandler::One(connection) => { @@ -752,27 +813,32 @@ where if let Some(event) = notify_one(&mut conn, event, cx) { let handler = PendingNotifyHandler::One(connection); this.pending_event = Some((peer_id, handler, event)); - return Poll::Pending + return Poll::Pending; } } } NotifyHandler::Any => { let ids = peer.connections().into_ids().collect(); - if let Some((event, ids)) = notify_any::<_, _, TBehaviour>(ids, &mut peer, event, cx) { + if let Some((event, ids)) = + notify_any::<_, _, TBehaviour>(ids, &mut peer, event, cx) + { let handler = PendingNotifyHandler::Any(ids); this.pending_event = Some((peer_id, handler, event)); - return Poll::Pending + return Poll::Pending; } } } } - }, + } Poll::Ready(NetworkBehaviourAction::ReportObservedAddr { address, score }) => { for addr in this.network.address_translation(&address) { this.add_external_address(addr, score); } - }, - Poll::Ready(NetworkBehaviourAction::CloseConnection { peer_id, connection }) => { + } + Poll::Ready(NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }) => { if let Some(mut peer) = this.network.peer(peer_id).into_connected() { match connection { CloseConnection::One(connection_id) => { @@ -785,7 +851,7 @@ where } } } - }, + } } } } @@ -814,8 +880,7 @@ fn notify_one<'a, THandlerInEvent>( conn: &mut EstablishedConnection<'a, THandlerInEvent>, event: THandlerInEvent, cx: &mut Context<'_>, -) -> Option -{ +) -> Option { match conn.poll_ready_notify_handler(cx) { Poll::Pending => Some(event), Poll::Ready(Err(())) => None, // connection is closing @@ -847,7 +912,10 @@ where TTrans: Transport, TBehaviour: NetworkBehaviour, THandler: IntoConnectionHandler, - THandler::Handler: ConnectionHandler, OutEvent = THandlerOutEvent> + THandler::Handler: ConnectionHandler< + InEvent = THandlerInEvent, + OutEvent = THandlerOutEvent, + >, { let mut pending = SmallVec::new(); let mut event = Some(event); // (1) @@ -861,19 +929,20 @@ where if let Err(e) = conn.notify_handler(e) { event = Some(e) // (2) } else { - break + break; } } } } } - event.and_then(|e| + event.and_then(|e| { if !pending.is_empty() { Some((e, pending)) } else { None - }) + } + }) } /// Stream of events returned by [`Swarm`]. @@ -890,9 +959,7 @@ where type Item = SwarmEvent, THandlerErr>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.as_mut() - .poll_next_event(cx) - .map(Some) + self.as_mut().poll_next_event(cx).map(Some) } } @@ -957,7 +1024,7 @@ where pub fn new( transport: transport::Boxed<(PeerId, StreamMuxerBox)>, behaviour: TBehaviour, - local_peer_id: PeerId + local_peer_id: PeerId, ) -> Self { SwarmBuilder { local_peer_id, @@ -1042,7 +1109,8 @@ where /// Builds a `Swarm` with the current configuration. pub fn build(mut self) -> Swarm { - let supported_protocols = self.behaviour + let supported_protocols = self + .behaviour .new_handler() .inbound_protocol() .protocol_info() @@ -1051,20 +1119,19 @@ where .collect(); // If no executor has been explicitly configured, try to set up a thread pool. - let network_cfg = self.network_config.or_else_with_executor(|| { - match ThreadPoolBuilder::new() - .name_prefix("libp2p-swarm-task-") - .create() - { - Ok(tp) => { - Some(Box::new(move |f| tp.spawn_ok(f))) - }, - Err(err) => { - log::warn!("Failed to create executor thread pool: {:?}", err); - None + let network_cfg = + self.network_config.or_else_with_executor(|| { + match ThreadPoolBuilder::new() + .name_prefix("libp2p-swarm-task-") + .create() + { + Ok(tp) => Some(Box::new(move |f| tp.spawn_ok(f))), + Err(err) => { + log::warn!("Failed to create executor thread pool: {:?}", err); + None + } } - } - }); + }); let network = Network::new(self.transport, self.local_peer_id, network_cfg); @@ -1093,7 +1160,7 @@ pub enum DialError { InvalidAddress(Multiaddr), /// [`NetworkBehaviour::addresses_of_peer`] returned no addresses /// for the peer to dial. - NoAddresses + NoAddresses, } impl From for DialError { @@ -1111,7 +1178,7 @@ impl fmt::Display for DialError { DialError::ConnectionLimit(err) => write!(f, "Dial error: {}", err), DialError::NoAddresses => write!(f, "Dial error: no addresses for peer."), DialError::InvalidAddress(a) => write!(f, "Dial error: invalid address: {}", a), - DialError::Banned => write!(f, "Dial error: peer is banned.") + DialError::Banned => write!(f, "Dial error: peer is banned."), } } } @@ -1122,7 +1189,7 @@ impl error::Error for DialError { DialError::ConnectionLimit(err) => Some(err), DialError::InvalidAddress(_) => None, DialError::NoAddresses => None, - DialError::Banned => None + DialError::Banned => None, } } } @@ -1130,14 +1197,12 @@ impl error::Error for DialError { /// Dummy implementation of [`NetworkBehaviour`] that doesn't do anything. #[derive(Clone)] pub struct DummyBehaviour { - keep_alive: KeepAlive + keep_alive: KeepAlive, } impl DummyBehaviour { pub fn with_keep_alive(keep_alive: KeepAlive) -> Self { - Self { - keep_alive - } + Self { keep_alive } } pub fn keep_alive_mut(&mut self) -> &mut KeepAlive { @@ -1148,7 +1213,7 @@ impl DummyBehaviour { impl Default for DummyBehaviour { fn default() -> Self { Self { - keep_alive: KeepAlive::No + keep_alive: KeepAlive::No, } } } @@ -1159,7 +1224,7 @@ impl NetworkBehaviour for DummyBehaviour { fn new_handler(&mut self) -> Self::ProtocolsHandler { protocols_handler::DummyProtocolsHandler { - keep_alive: self.keep_alive + keep_alive: self.keep_alive, } } @@ -1167,32 +1232,33 @@ impl NetworkBehaviour for DummyBehaviour { &mut self, _: PeerId, _: ConnectionId, - event: ::OutEvent + event: ::OutEvent, ) { void::unreachable(event) } - fn poll(&mut self, _: &mut Context<'_>, _: &mut impl PollParameters) -> - Poll::InEvent, Self::OutEvent>> - { + fn poll( + &mut self, + _: &mut Context<'_>, + _: &mut impl PollParameters, + ) -> Poll< + NetworkBehaviourAction< + ::InEvent, + Self::OutEvent, + >, + > { Poll::Pending } } #[cfg(test)] mod tests { + use super::*; use crate::protocols_handler::DummyProtocolsHandler; - use crate::test::{MockBehaviour, CallTraceBehaviour}; - use futures::{future, executor}; - use libp2p_core::{ - identity, - upgrade, - multiaddr, - transport - }; + use crate::test::{CallTraceBehaviour, MockBehaviour}; + use futures::{executor, future}; + use libp2p_core::{identity, multiaddr, transport, upgrade}; use libp2p_noise as noise; - use super::*; // Test execution state. // Connection => Disconnecting => Connecting. @@ -1205,11 +1271,13 @@ mod tests { where T: ProtocolsHandler + Clone, T::OutEvent: Clone, - O: Send + 'static + O: Send + 'static, { let id_keys = identity::Keypair::generate_ed25519(); let pubkey = id_keys.public(); - let noise_keys = noise::Keypair::::new().into_authentic(&id_keys).unwrap(); + let noise_keys = noise::Keypair::::new() + .into_authentic(&id_keys) + .unwrap(); let transport = transport::MemoryTransport::default() .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) @@ -1275,7 +1343,9 @@ mod tests { fn test_connect_disconnect_ban() { // Since the test does not try to open any substreams, we can // use the dummy protocols handler. - let handler_proto = DummyProtocolsHandler { keep_alive: KeepAlive::Yes }; + let handler_proto = DummyProtocolsHandler { + keep_alive: KeepAlive::Yes, + }; let mut swarm1 = new_test_swarm::<_, ()>(handler_proto.clone()); let mut swarm2 = new_test_swarm::<_, ()>(handler_proto); @@ -1306,7 +1376,7 @@ mod tests { State::Connecting => { if swarms_connected(&swarm1, &swarm2, num_connections) { if banned { - return Poll::Ready(()) + return Poll::Ready(()); } swarm2.ban_peer_id(swarm1_id.clone()); swarm1.behaviour.reset(); @@ -1318,7 +1388,7 @@ mod tests { State::Disconnecting => { if swarms_disconnected(&swarm1, &swarm2, num_connections) { if unbanned { - return Poll::Ready(()) + return Poll::Ready(()); } // Unban the first peer and reconnect. swarm2.unban_peer_id(swarm1_id.clone()); @@ -1334,7 +1404,7 @@ mod tests { } if poll1.is_pending() && poll2.is_pending() { - return Poll::Pending + return Poll::Pending; } } })) @@ -1350,7 +1420,9 @@ mod tests { fn test_swarm_disconnect() { // Since the test does not try to open any substreams, we can // use the dummy protocols handler. - let handler_proto = DummyProtocolsHandler { keep_alive: KeepAlive::Yes }; + let handler_proto = DummyProtocolsHandler { + keep_alive: KeepAlive::Yes, + }; let mut swarm1 = new_test_swarm::<_, ()>(handler_proto.clone()); let mut swarm2 = new_test_swarm::<_, ()>(handler_proto); @@ -1371,41 +1443,41 @@ mod tests { } let mut state = State::Connecting; - executor::block_on(future::poll_fn(move |cx| { - loop { - let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); - let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); - match state { - State::Connecting => { - if swarms_connected(&swarm1, &swarm2, num_connections) { - if reconnected { - return Poll::Ready(()) - } - swarm2.disconnect_peer_id(swarm1_id.clone()).expect("Error disconnecting"); - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - state = State::Disconnecting; + executor::block_on(future::poll_fn(move |cx| loop { + let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); + let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); + match state { + State::Connecting => { + if swarms_connected(&swarm1, &swarm2, num_connections) { + if reconnected { + return Poll::Ready(()); } + swarm2 + .disconnect_peer_id(swarm1_id.clone()) + .expect("Error disconnecting"); + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + state = State::Disconnecting; } - State::Disconnecting => { - if swarms_disconnected(&swarm1, &swarm2, num_connections) { - if reconnected { - return Poll::Ready(()) - } - reconnected = true; - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - for _ in 0..num_connections { - swarm2.dial_addr(addr1.clone()).unwrap(); - } - state = State::Connecting; + } + State::Disconnecting => { + if swarms_disconnected(&swarm1, &swarm2, num_connections) { + if reconnected { + return Poll::Ready(()); + } + reconnected = true; + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + for _ in 0..num_connections { + swarm2.dial_addr(addr1.clone()).unwrap(); } + state = State::Connecting; } } + } - if poll1.is_pending() && poll2.is_pending() { - return Poll::Pending - } + if poll1.is_pending() && poll2.is_pending() { + return Poll::Pending; } })) } @@ -1421,7 +1493,9 @@ mod tests { fn test_behaviour_disconnect_all() { // Since the test does not try to open any substreams, we can // use the dummy protocols handler. - let handler_proto = DummyProtocolsHandler { keep_alive: KeepAlive::Yes }; + let handler_proto = DummyProtocolsHandler { + keep_alive: KeepAlive::Yes, + }; let mut swarm1 = new_test_swarm::<_, ()>(handler_proto.clone()); let mut swarm2 = new_test_swarm::<_, ()>(handler_proto); @@ -1442,48 +1516,44 @@ mod tests { } let mut state = State::Connecting; - executor::block_on(future::poll_fn(move |cx| { - loop { - let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); - let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); - match state { - State::Connecting => { - if swarms_connected(&swarm1, &swarm2, num_connections) { - if reconnected { - return Poll::Ready(()) - } - swarm2 - .behaviour - .inner() - .next_action - .replace(NetworkBehaviourAction::CloseConnection { - peer_id: swarm1_id.clone(), - connection: CloseConnection::All, - }); - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - state = State::Disconnecting; + executor::block_on(future::poll_fn(move |cx| loop { + let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); + let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); + match state { + State::Connecting => { + if swarms_connected(&swarm1, &swarm2, num_connections) { + if reconnected { + return Poll::Ready(()); } + swarm2.behaviour.inner().next_action.replace( + NetworkBehaviourAction::CloseConnection { + peer_id: swarm1_id.clone(), + connection: CloseConnection::All, + }, + ); + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + state = State::Disconnecting; } - State::Disconnecting => { - if swarms_disconnected(&swarm1, &swarm2, num_connections) { - if reconnected { - return Poll::Ready(()) - } - reconnected = true; - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - for _ in 0..num_connections { - swarm2.dial_addr(addr1.clone()).unwrap(); - } - state = State::Connecting; + } + State::Disconnecting => { + if swarms_disconnected(&swarm1, &swarm2, num_connections) { + if reconnected { + return Poll::Ready(()); } + reconnected = true; + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + for _ in 0..num_connections { + swarm2.dial_addr(addr1.clone()).unwrap(); + } + state = State::Connecting; } } + } - if poll1.is_pending() && poll2.is_pending() { - return Poll::Pending - } + if poll1.is_pending() && poll2.is_pending() { + return Poll::Pending; } })) } @@ -1499,7 +1569,9 @@ mod tests { fn test_behaviour_disconnect_one() { // Since the test does not try to open any substreams, we can // use the dummy protocols handler. - let handler_proto = DummyProtocolsHandler { keep_alive: KeepAlive::Yes }; + let handler_proto = DummyProtocolsHandler { + keep_alive: KeepAlive::Yes, + }; let mut swarm1 = new_test_swarm::<_, ()>(handler_proto.clone()); let mut swarm2 = new_test_swarm::<_, ()>(handler_proto); @@ -1520,49 +1592,48 @@ mod tests { let mut state = State::Connecting; let mut disconnected_conn_id = None; - executor::block_on(future::poll_fn(move |cx| { - loop { - let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); - let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); - match state { - State::Connecting => { - if swarms_connected(&swarm1, &swarm2, num_connections) { - disconnected_conn_id = { - let conn_id = swarm2.behaviour.inject_connection_established[num_connections / 2].1; - swarm2 - .behaviour - .inner() - .next_action - .replace(NetworkBehaviourAction::CloseConnection { - peer_id: swarm1_id.clone(), - connection: CloseConnection::One(conn_id), - }); - Some(conn_id) - }; - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - state = State::Disconnecting; - } + executor::block_on(future::poll_fn(move |cx| loop { + let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); + let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); + match state { + State::Connecting => { + if swarms_connected(&swarm1, &swarm2, num_connections) { + disconnected_conn_id = { + let conn_id = swarm2.behaviour.inject_connection_established + [num_connections / 2] + .1; + swarm2.behaviour.inner().next_action.replace( + NetworkBehaviourAction::CloseConnection { + peer_id: swarm1_id.clone(), + connection: CloseConnection::One(conn_id), + }, + ); + Some(conn_id) + }; + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + state = State::Disconnecting; } - State::Disconnecting => { - for s in &[&swarm1, &swarm2] { - assert_eq!(s.behaviour.inject_disconnected.len(), 0); - assert_eq!(s.behaviour.inject_connection_established.len(), 0); - assert_eq!(s.behaviour.inject_connected.len(), 0); - } - if [&swarm1, &swarm2].iter().all(|s| { - s.behaviour.inject_connection_closed.len() == 1 - }) { - let conn_id = swarm2.behaviour.inject_connection_closed[0].1; - assert_eq!(Some(conn_id), disconnected_conn_id); - return Poll::Ready(()); - } + } + State::Disconnecting => { + for s in &[&swarm1, &swarm2] { + assert_eq!(s.behaviour.inject_disconnected.len(), 0); + assert_eq!(s.behaviour.inject_connection_established.len(), 0); + assert_eq!(s.behaviour.inject_connected.len(), 0); + } + if [&swarm1, &swarm2] + .iter() + .all(|s| s.behaviour.inject_connection_closed.len() == 1) + { + let conn_id = swarm2.behaviour.inject_connection_closed[0].1; + assert_eq!(Some(conn_id), disconnected_conn_id); + return Poll::Ready(()); } } + } - if poll1.is_pending() && poll2.is_pending() { - return Poll::Pending - } + if poll1.is_pending() && poll2.is_pending() { + return Poll::Pending; } })) } diff --git a/swarm/src/protocols_handler.rs b/swarm/src/protocols_handler.rs index 58ae351673b..911693f32af 100644 --- a/swarm/src/protocols_handler.rs +++ b/swarm/src/protocols_handler.rs @@ -40,23 +40,14 @@ mod dummy; mod map_in; mod map_out; +pub mod multi; mod node_handler; mod one_shot; mod select; -pub mod multi; -pub use crate::upgrade::{ - InboundUpgradeSend, - OutboundUpgradeSend, - UpgradeInfoSend, -}; - -use libp2p_core::{ - ConnectedPoint, - Multiaddr, - PeerId, - upgrade::UpgradeError, -}; +pub use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend}; + +use libp2p_core::{upgrade::UpgradeError, ConnectedPoint, Multiaddr, PeerId}; use std::{cmp::Ordering, error, fmt, task::Context, task::Poll, time::Duration}; use wasm_timer::Instant; @@ -128,7 +119,7 @@ pub trait ProtocolsHandler: Send + 'static { fn inject_fully_negotiated_inbound( &mut self, protocol: ::Output, - info: Self::InboundOpenInfo + info: Self::InboundOpenInfo, ); /// Injects the output of a successful upgrade on a new outbound substream. @@ -138,7 +129,7 @@ pub trait ProtocolsHandler: Send + 'static { fn inject_fully_negotiated_outbound( &mut self, protocol: ::Output, - info: Self::OutboundOpenInfo + info: Self::OutboundOpenInfo, ); /// Injects an event coming from the outside in the handler. @@ -151,17 +142,16 @@ pub trait ProtocolsHandler: Send + 'static { fn inject_dial_upgrade_error( &mut self, info: Self::OutboundOpenInfo, - error: ProtocolsHandlerUpgrErr< - ::Error - > + error: ProtocolsHandlerUpgrErr<::Error>, ); /// Indicates to the handler that upgrading an inbound substream to the given protocol has failed. fn inject_listen_upgrade_error( &mut self, _: Self::InboundOpenInfo, - _: ProtocolsHandlerUpgrErr<::Error> - ) {} + _: ProtocolsHandlerUpgrErr<::Error>, + ) { + } /// Returns until when the connection should be kept alive. /// @@ -186,8 +176,16 @@ pub trait ProtocolsHandler: Send + 'static { fn connection_keep_alive(&self) -> KeepAlive; /// Should behave like `Stream::poll()`. - fn poll(&mut self, cx: &mut Context<'_>) -> Poll< - ProtocolsHandlerEvent + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, >; /// Adds a closure that turns the input event into something else. @@ -315,7 +313,7 @@ pub enum ProtocolsHandlerEvent + protocol: SubstreamProtocol, }, /// Close the connection for the given reason. @@ -341,7 +339,7 @@ impl match self { ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } => { ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: protocol.map_info(map) + protocol: protocol.map_info(map), } } ProtocolsHandlerEvent::Custom(val) => ProtocolsHandlerEvent::Custom(val), @@ -361,7 +359,7 @@ impl match self { ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } => { ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: protocol.map_upgrade(map) + protocol: protocol.map_upgrade(map), } } ProtocolsHandlerEvent::Custom(val) => ProtocolsHandlerEvent::Custom(val), @@ -419,12 +417,12 @@ impl ProtocolsHandlerUpgrErr { /// Map the inner [`UpgradeError`] type. pub fn map_upgrade_err(self, f: F) -> ProtocolsHandlerUpgrErr where - F: FnOnce(UpgradeError) -> UpgradeError + F: FnOnce(UpgradeError) -> UpgradeError, { match self { ProtocolsHandlerUpgrErr::Timeout => ProtocolsHandlerUpgrErr::Timeout, ProtocolsHandlerUpgrErr::Timer => ProtocolsHandlerUpgrErr::Timer, - ProtocolsHandlerUpgrErr::Upgrade(e) => ProtocolsHandlerUpgrErr::Upgrade(f(e)) + ProtocolsHandlerUpgrErr::Upgrade(e) => ProtocolsHandlerUpgrErr::Upgrade(f(e)), } } } @@ -437,10 +435,10 @@ where match self { ProtocolsHandlerUpgrErr::Timeout => { write!(f, "Timeout error while opening a substream") - }, + } ProtocolsHandlerUpgrErr::Timer => { write!(f, "Timer error while opening a substream") - }, + } ProtocolsHandlerUpgrErr::Upgrade(err) => write!(f, "{}", err), } } @@ -448,7 +446,7 @@ where impl error::Error for ProtocolsHandlerUpgrErr where - TUpgrErr: error::Error + 'static + TUpgrErr: error::Error + 'static, { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { @@ -467,7 +465,11 @@ pub trait IntoProtocolsHandler: Send + 'static { /// Builds the protocols handler. /// /// The `PeerId` is the id of the node the handler is going to handle. - fn into_handler(self, remote_peer_id: &PeerId, connected_point: &ConnectedPoint) -> Self::Handler; + fn into_handler( + self, + remote_peer_id: &PeerId, + connected_point: &ConnectedPoint, + ) -> Self::Handler; /// Return the handler's inbound protocol. fn inbound_protocol(&self) -> ::InboundProtocol; @@ -492,7 +494,8 @@ pub trait IntoProtocolsHandler: Send + 'static { } impl IntoProtocolsHandler for T -where T: ProtocolsHandler +where + T: ProtocolsHandler, { type Handler = Self; @@ -537,9 +540,9 @@ impl Ord for KeepAlive { use self::KeepAlive::*; match (self, other) { - (No, No) | (Yes, Yes) => Ordering::Equal, - (No, _) | (_, Yes) => Ordering::Less, - (_, No) | (Yes, _) => Ordering::Greater, + (No, No) | (Yes, Yes) => Ordering::Equal, + (No, _) | (_, Yes) => Ordering::Less, + (_, No) | (Yes, _) => Ordering::Greater, (Until(t1), Until(t2)) => t1.cmp(t2), } } diff --git a/swarm/src/protocols_handler/dummy.rs b/swarm/src/protocols_handler/dummy.rs index 764f95fe2cf..97dd55ce793 100644 --- a/swarm/src/protocols_handler/dummy.rs +++ b/swarm/src/protocols_handler/dummy.rs @@ -18,15 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::NegotiatedSubstream; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol, +}; +use crate::NegotiatedSubstream; +use libp2p_core::{ + upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade}, + Multiaddr, }; -use libp2p_core::{Multiaddr, upgrade::{InboundUpgrade, OutboundUpgrade, DeniedUpgrade}}; use std::task::{Context, Poll}; use void::Void; @@ -39,7 +38,7 @@ pub struct DummyProtocolsHandler { impl Default for DummyProtocolsHandler { fn default() -> Self { DummyProtocolsHandler { - keep_alive: KeepAlive::No + keep_alive: KeepAlive::No, } } } @@ -60,14 +59,14 @@ impl ProtocolsHandler for DummyProtocolsHandler { fn inject_fully_negotiated_inbound( &mut self, _: >::Output, - _: Self::InboundOpenInfo + _: Self::InboundOpenInfo, ) { } fn inject_fully_negotiated_outbound( &mut self, _: >::Output, - _: Self::OutboundOpenInfo + _: Self::OutboundOpenInfo, ) { } @@ -75,9 +74,23 @@ impl ProtocolsHandler for DummyProtocolsHandler { fn inject_address_change(&mut self, _: &Multiaddr) {} - fn inject_dial_upgrade_error(&mut self, _: Self::OutboundOpenInfo, _: ProtocolsHandlerUpgrErr<>::Error>) {} + fn inject_dial_upgrade_error( + &mut self, + _: Self::OutboundOpenInfo, + _: ProtocolsHandlerUpgrErr< + >::Error, + >, + ) { + } - fn inject_listen_upgrade_error(&mut self, _: Self::InboundOpenInfo, _: ProtocolsHandlerUpgrErr<>::Error>) {} + fn inject_listen_upgrade_error( + &mut self, + _: Self::InboundOpenInfo, + _: ProtocolsHandlerUpgrErr< + >::Error, + >, + ) { + } fn connection_keep_alive(&self) -> KeepAlive { self.keep_alive @@ -87,7 +100,12 @@ impl ProtocolsHandler for DummyProtocolsHandler { &mut self, _: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { Poll::Pending } diff --git a/swarm/src/protocols_handler/map_in.rs b/swarm/src/protocols_handler/map_in.rs index 77ac5f912d9..1c1e436e42d 100644 --- a/swarm/src/protocols_handler/map_in.rs +++ b/swarm/src/protocols_handler/map_in.rs @@ -18,14 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use libp2p_core::Multiaddr; use std::{fmt::Debug, marker::PhantomData, task::Context, task::Poll}; @@ -69,7 +65,7 @@ where fn inject_fully_negotiated_inbound( &mut self, protocol: ::Output, - info: Self::InboundOpenInfo + info: Self::InboundOpenInfo, ) { self.inner.inject_fully_negotiated_inbound(protocol, info) } @@ -77,7 +73,7 @@ where fn inject_fully_negotiated_outbound( &mut self, protocol: ::Output, - info: Self::OutboundOpenInfo + info: Self::OutboundOpenInfo, ) { self.inner.inject_fully_negotiated_outbound(protocol, info) } @@ -92,11 +88,19 @@ where self.inner.inject_address_change(addr) } - fn inject_dial_upgrade_error(&mut self, info: Self::OutboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_dial_upgrade_error( + &mut self, + info: Self::OutboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { self.inner.inject_dial_upgrade_error(info, error) } - fn inject_listen_upgrade_error(&mut self, info: Self::InboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_listen_upgrade_error( + &mut self, + info: Self::InboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { self.inner.inject_listen_upgrade_error(info, error) } @@ -108,7 +112,12 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { self.inner.poll(cx) } diff --git a/swarm/src/protocols_handler/map_out.rs b/swarm/src/protocols_handler/map_out.rs index 9df2ace9256..77d0e1eac93 100644 --- a/swarm/src/protocols_handler/map_out.rs +++ b/swarm/src/protocols_handler/map_out.rs @@ -18,14 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use libp2p_core::Multiaddr; use std::fmt::Debug; use std::task::{Context, Poll}; @@ -39,10 +35,7 @@ pub struct MapOutEvent { impl MapOutEvent { /// Creates a `MapOutEvent`. pub(crate) fn new(inner: TProtoHandler, map: TMap) -> Self { - MapOutEvent { - inner, - map, - } + MapOutEvent { inner, map } } } @@ -68,7 +61,7 @@ where fn inject_fully_negotiated_inbound( &mut self, protocol: ::Output, - info: Self::InboundOpenInfo + info: Self::InboundOpenInfo, ) { self.inner.inject_fully_negotiated_inbound(protocol, info) } @@ -76,7 +69,7 @@ where fn inject_fully_negotiated_outbound( &mut self, protocol: ::Output, - info: Self::OutboundOpenInfo + info: Self::OutboundOpenInfo, ) { self.inner.inject_fully_negotiated_outbound(protocol, info) } @@ -89,11 +82,19 @@ where self.inner.inject_address_change(addr) } - fn inject_dial_upgrade_error(&mut self, info: Self::OutboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_dial_upgrade_error( + &mut self, + info: Self::OutboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { self.inner.inject_dial_upgrade_error(info, error) } - fn inject_listen_upgrade_error(&mut self, info: Self::InboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_listen_upgrade_error( + &mut self, + info: Self::InboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { self.inner.inject_listen_upgrade_error(info, error) } @@ -105,15 +106,18 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { - self.inner.poll(cx).map(|ev| { - match ev { - ProtocolsHandlerEvent::Custom(ev) => ProtocolsHandlerEvent::Custom((self.map)(ev)), - ProtocolsHandlerEvent::Close(err) => ProtocolsHandlerEvent::Close(err), - ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } => { - ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } - } + self.inner.poll(cx).map(|ev| match ev { + ProtocolsHandlerEvent::Custom(ev) => ProtocolsHandlerEvent::Custom((self.map)(ev)), + ProtocolsHandlerEvent::Close(err) => ProtocolsHandlerEvent::Close(err), + ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } => { + ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } } }) } diff --git a/swarm/src/protocols_handler/multi.rs b/swarm/src/protocols_handler/multi.rs index 64821ca3d35..f865443766c 100644 --- a/swarm/src/protocols_handler/multi.rs +++ b/swarm/src/protocols_handler/multi.rs @@ -21,23 +21,15 @@ //! A [`ProtocolsHandler`] implementation that combines multiple other `ProtocolsHandler`s //! indexed by some key. -use crate::NegotiatedSubstream; use crate::protocols_handler::{ - KeepAlive, - IntoProtocolsHandler, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, - SubstreamProtocol -}; -use crate::upgrade::{ - InboundUpgradeSend, - OutboundUpgradeSend, - UpgradeInfoSend + IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend}; +use crate::NegotiatedSubstream; use futures::{future::BoxFuture, prelude::*}; +use libp2p_core::upgrade::{NegotiationError, ProtocolError, ProtocolName, UpgradeError}; use libp2p_core::{ConnectedPoint, Multiaddr, PeerId}; -use libp2p_core::upgrade::{ProtocolName, UpgradeError, NegotiationError, ProtocolError}; use rand::Rng; use std::{ cmp, @@ -47,19 +39,19 @@ use std::{ hash::Hash, iter::{self, FromIterator}, task::{Context, Poll}, - time::Duration + time::Duration, }; /// A [`ProtocolsHandler`] for multiple `ProtocolsHandler`s of the same type. #[derive(Clone)] pub struct MultiHandler { - handlers: HashMap + handlers: HashMap, } impl fmt::Debug for MultiHandler where K: fmt::Debug + Eq + Hash, - H: fmt::Debug + H: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MultiHandler") @@ -71,17 +63,23 @@ where impl MultiHandler where K: Hash + Eq, - H: ProtocolsHandler + H: ProtocolsHandler, { /// Create and populate a `MultiHandler` from the given handler iterator. /// /// It is an error for any two protocols handlers to share the same protocol name. pub fn try_from_iter(iter: I) -> Result where - I: IntoIterator + I: IntoIterator, { - let m = MultiHandler { handlers: HashMap::from_iter(iter) }; - uniq_proto_names(m.handlers.values().map(|h| h.listen_protocol().into_upgrade().0))?; + let m = MultiHandler { + handlers: HashMap::from_iter(iter), + }; + uniq_proto_names( + m.handlers + .values() + .map(|h| h.listen_protocol().into_upgrade().0), + )?; Ok(m) } } @@ -91,7 +89,7 @@ where K: Clone + Debug + Hash + Eq + Send + 'static, H: ProtocolsHandler, H::InboundProtocol: InboundUpgradeSend, - H::OutboundProtocol: OutboundUpgradeSend + H::OutboundProtocol: OutboundUpgradeSend, { type InEvent = (K, ::InEvent); type OutEvent = (K, ::OutEvent); @@ -102,28 +100,31 @@ where type OutboundOpenInfo = (K, ::OutboundOpenInfo); fn listen_protocol(&self) -> SubstreamProtocol { - let (upgrade, info, timeout) = self.handlers.iter() + let (upgrade, info, timeout) = self + .handlers + .iter() .map(|(key, handler)| { let proto = handler.listen_protocol(); let timeout = *proto.timeout(); let (upgrade, info) = proto.into_upgrade(); (key.clone(), (upgrade, info, timeout)) }) - .fold((Upgrade::new(), Info::new(), Duration::from_secs(0)), + .fold( + (Upgrade::new(), Info::new(), Duration::from_secs(0)), |(mut upg, mut inf, mut timeout), (k, (u, i, t))| { upg.upgrades.push((k.clone(), u)); inf.infos.push((k, i)); timeout = cmp::max(timeout, t); (upg, inf, timeout) - } + }, ); SubstreamProtocol::new(upgrade, info).with_timeout(timeout) } - fn inject_fully_negotiated_outbound ( + fn inject_fully_negotiated_outbound( &mut self, protocol: ::Output, - (key, arg): Self::OutboundOpenInfo + (key, arg): Self::OutboundOpenInfo, ) { if let Some(h) = self.handlers.get_mut(&key) { h.inject_fully_negotiated_outbound(protocol, arg) @@ -132,10 +133,10 @@ where } } - fn inject_fully_negotiated_inbound ( + fn inject_fully_negotiated_inbound( &mut self, (key, arg): ::Output, - mut info: Self::InboundOpenInfo + mut info: Self::InboundOpenInfo, ) { if let Some(h) = self.handlers.get_mut(&key) { if let Some(i) = info.take(&key) { @@ -160,10 +161,10 @@ where } } - fn inject_dial_upgrade_error ( + fn inject_dial_upgrade_error( &mut self, (key, arg): Self::OutboundOpenInfo, - error: ProtocolsHandlerUpgrErr<::Error> + error: ProtocolsHandlerUpgrErr<::Error>, ) { if let Some(h) = self.handlers.get_mut(&key) { h.inject_dial_upgrade_error(arg, error) @@ -175,77 +176,118 @@ where fn inject_listen_upgrade_error( &mut self, mut info: Self::InboundOpenInfo, - error: ProtocolsHandlerUpgrErr<::Error> + error: ProtocolsHandlerUpgrErr<::Error>, ) { match error { - ProtocolsHandlerUpgrErr::Timer => + ProtocolsHandlerUpgrErr::Timer => { for (k, h) in &mut self.handlers { if let Some(i) = info.take(k) { h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timer) } } - ProtocolsHandlerUpgrErr::Timeout => + } + ProtocolsHandlerUpgrErr::Timeout => { for (k, h) in &mut self.handlers { if let Some(i) = info.take(k) { h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timeout) } } - ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => + } + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { for (k, h) in &mut self.handlers { if let Some(i) = info.take(k) { - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed))) + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::Failed, + )), + ) } } - ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::ProtocolError(e))) => - match e { - ProtocolError::IoError(e) => - for (k, h) in &mut self.handlers { - if let Some(i) = info.take(k) { - let e = NegotiationError::ProtocolError(ProtocolError::IoError(e.kind().into())); - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e))) - } + } + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::ProtocolError(e), + )) => match e { + ProtocolError::IoError(e) => { + for (k, h) in &mut self.handlers { + if let Some(i) = info.take(k) { + let e = NegotiationError::ProtocolError(ProtocolError::IoError( + e.kind().into(), + )); + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)), + ) } - ProtocolError::InvalidMessage => - for (k, h) in &mut self.handlers { - if let Some(i) = info.take(k) { - let e = NegotiationError::ProtocolError(ProtocolError::InvalidMessage); - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e))) - } + } + } + ProtocolError::InvalidMessage => { + for (k, h) in &mut self.handlers { + if let Some(i) = info.take(k) { + let e = NegotiationError::ProtocolError(ProtocolError::InvalidMessage); + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)), + ) } - ProtocolError::InvalidProtocol => - for (k, h) in &mut self.handlers { - if let Some(i) = info.take(k) { - let e = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol); - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e))) - } + } + } + ProtocolError::InvalidProtocol => { + for (k, h) in &mut self.handlers { + if let Some(i) = info.take(k) { + let e = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol); + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)), + ) } - ProtocolError::TooManyProtocols => - for (k, h) in &mut self.handlers { - if let Some(i) = info.take(k) { - let e = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols); - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e))) - } + } + } + ProtocolError::TooManyProtocols => { + for (k, h) in &mut self.handlers { + if let Some(i) = info.take(k) { + let e = + NegotiationError::ProtocolError(ProtocolError::TooManyProtocols); + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)), + ) } + } } - ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply((k, e))) => + }, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply((k, e))) => { if let Some(h) = self.handlers.get_mut(&k) { if let Some(i) = info.take(&k) { - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e))) + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)), + ) } } + } } } fn connection_keep_alive(&self) -> KeepAlive { - self.handlers.values() + self.handlers + .values() .map(|h| h.connection_keep_alive()) .max() .unwrap_or(KeepAlive::No) } - fn poll(&mut self, cx: &mut Context<'_>) - -> Poll> - { + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, + > { // Calling `gen_range(0, 0)` (see below) would panic, so we have return early to avoid // that situation. if self.handlers.is_empty() { @@ -257,15 +299,19 @@ where for (k, h) in self.handlers.iter_mut().skip(pos) { if let Poll::Ready(e) = h.poll(cx) { - let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p)); - return Poll::Ready(e) + let e = e + .map_outbound_open_info(|i| (k.clone(), i)) + .map_custom(|p| (k.clone(), p)); + return Poll::Ready(e); } } for (k, h) in self.handlers.iter_mut().take(pos) { if let Poll::Ready(e) = h.poll(cx) { - let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p)); - return Poll::Ready(e) + let e = e + .map_outbound_open_info(|i| (k.clone(), i)) + .map_custom(|p| (k.clone(), p)); + return Poll::Ready(e); } } @@ -276,13 +322,13 @@ where /// A [`IntoProtocolsHandler`] for multiple other `IntoProtocolsHandler`s. #[derive(Clone)] pub struct IntoMultiHandler { - handlers: HashMap + handlers: HashMap, } impl fmt::Debug for IntoMultiHandler where K: fmt::Debug + Eq + Hash, - H: fmt::Debug + H: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("IntoMultiHandler") @@ -291,20 +337,21 @@ where } } - impl IntoMultiHandler where K: Hash + Eq, - H: IntoProtocolsHandler + H: IntoProtocolsHandler, { /// Create and populate an `IntoMultiHandler` from the given iterator. /// /// It is an error for any two protocols handlers to share the same protocol name. pub fn try_from_iter(iter: I) -> Result where - I: IntoIterator + I: IntoIterator, { - let m = IntoMultiHandler { handlers: HashMap::from_iter(iter) }; + let m = IntoMultiHandler { + handlers: HashMap::from_iter(iter), + }; uniq_proto_names(m.handlers.values().map(|h| h.inbound_protocol()))?; Ok(m) } @@ -313,23 +360,27 @@ where impl IntoProtocolsHandler for IntoMultiHandler where K: Debug + Clone + Eq + Hash + Send + 'static, - H: IntoProtocolsHandler + H: IntoProtocolsHandler, { type Handler = MultiHandler; fn into_handler(self, p: &PeerId, c: &ConnectedPoint) -> Self::Handler { MultiHandler { - handlers: self.handlers.into_iter() + handlers: self + .handlers + .into_iter() .map(|(k, h)| (k, h.into_handler(p, c))) - .collect() + .collect(), } } fn inbound_protocol(&self) -> ::InboundProtocol { Upgrade { - upgrades: self.handlers.iter() + upgrades: self + .handlers + .iter() .map(|(k, h)| (k.clone(), h.inbound_protocol())) - .collect() + .collect(), } } } @@ -347,7 +398,7 @@ impl ProtocolName for IndexedProtoName { /// The aggregated `InboundOpenInfo`s of supported inbound substream protocols. #[derive(Clone)] pub struct Info { - infos: Vec<(K, I)> + infos: Vec<(K, I)>, } impl Info { @@ -357,7 +408,7 @@ impl Info { pub fn take(&mut self, k: &K) -> Option { if let Some(p) = self.infos.iter().position(|(key, _)| key == k) { - return Some(self.infos.remove(p).1) + return Some(self.infos.remove(p).1); } None } @@ -366,19 +417,21 @@ impl Info { /// Inbound and outbound upgrade for all `ProtocolsHandler`s. #[derive(Clone)] pub struct Upgrade { - upgrades: Vec<(K, H)> + upgrades: Vec<(K, H)>, } impl Upgrade { fn new() -> Self { - Upgrade { upgrades: Vec::new() } + Upgrade { + upgrades: Vec::new(), + } } } impl fmt::Debug for Upgrade where K: fmt::Debug + Eq + Hash, - H: fmt::Debug + H: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Upgrade") @@ -390,13 +443,15 @@ where impl UpgradeInfoSend for Upgrade where H: UpgradeInfoSend, - K: Send + 'static + K: Send + 'static, { type Info = IndexedProtoName; type InfoIter = std::vec::IntoIter; fn protocol_info(&self) -> Self::InfoIter { - self.upgrades.iter().enumerate() + self.upgrades + .iter() + .enumerate() .map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info())) .flatten() .map(|(i, h)| IndexedProtoName(i, h)) @@ -408,21 +463,20 @@ where impl InboundUpgradeSend for Upgrade where H: InboundUpgradeSend, - K: Send + 'static + K: Send + 'static, { type Output = (K, ::Output); - type Error = (K, ::Error); + type Error = (K, ::Error); type Future = BoxFuture<'static, Result>; fn upgrade_inbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future { let IndexedProtoName(index, info) = info; let (key, upgrade) = self.upgrades.remove(index); - upgrade.upgrade_inbound(resource, info) - .map(move |out| { - match out { - Ok(o) => Ok((key, o)), - Err(e) => Err((key, e)) - } + upgrade + .upgrade_inbound(resource, info) + .map(move |out| match out { + Ok(o) => Ok((key, o)), + Err(e) => Err((key, e)), }) .boxed() } @@ -431,21 +485,20 @@ where impl OutboundUpgradeSend for Upgrade where H: OutboundUpgradeSend, - K: Send + 'static + K: Send + 'static, { type Output = (K, ::Output); - type Error = (K, ::Error); + type Error = (K, ::Error); type Future = BoxFuture<'static, Result>; fn upgrade_outbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future { let IndexedProtoName(index, info) = info; let (key, upgrade) = self.upgrades.remove(index); - upgrade.upgrade_outbound(resource, info) - .map(move |out| { - match out { - Ok(o) => Ok((key, o)), - Err(e) => Err((key, e)) - } + upgrade + .upgrade_outbound(resource, info) + .map(move |out| match out { + Ok(o) => Ok((key, o)), + Err(e) => Err((key, e)), }) .boxed() } @@ -455,14 +508,14 @@ where fn uniq_proto_names(iter: I) -> Result<(), DuplicateProtonameError> where I: Iterator, - T: UpgradeInfoSend + T: UpgradeInfoSend, { let mut set = HashSet::new(); for infos in iter { for i in infos.protocol_info() { let v = Vec::from(i.protocol_name()); if set.contains(&v) { - return Err(DuplicateProtonameError(v)) + return Err(DuplicateProtonameError(v)); } else { set.insert(v); } diff --git a/swarm/src/protocols_handler/node_handler.rs b/swarm/src/protocols_handler/node_handler.rs index 72730117cc3..edb383282cd 100644 --- a/swarm/src/protocols_handler/node_handler.rs +++ b/swarm/src/protocols_handler/node_handler.rs @@ -18,29 +18,22 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::upgrade::SendWrapper; use crate::protocols_handler::{ - KeepAlive, - ProtocolsHandler, - IntoProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, }; +use crate::upgrade::SendWrapper; use futures::prelude::*; use futures::stream::FuturesUnordered; use libp2p_core::{ - Multiaddr, - Connected, connection::{ - ConnectionHandler, - ConnectionHandlerEvent, - IntoConnectionHandler, - Substream, + ConnectionHandler, ConnectionHandlerEvent, IntoConnectionHandler, Substream, SubstreamEndpoint, }, muxing::StreamMuxerBox, - upgrade::{self, InboundUpgradeApply, OutboundUpgradeApply, UpgradeError} + upgrade::{self, InboundUpgradeApply, OutboundUpgradeApply, UpgradeError}, + Connected, Multiaddr, }; use std::{error, fmt, pin::Pin, task::Context, task::Poll, time::Duration}; use wasm_timer::{Delay, Instant}; @@ -55,7 +48,7 @@ pub struct NodeHandlerWrapperBuilder { impl NodeHandlerWrapperBuilder where - TIntoProtoHandler: IntoProtocolsHandler + TIntoProtoHandler: IntoProtocolsHandler, { /// Builds a `NodeHandlerWrapperBuilder`. pub(crate) fn new(handler: TIntoProtoHandler) -> Self { @@ -67,7 +60,7 @@ where pub(crate) fn with_substream_upgrade_protocol_override( mut self, - version: Option + version: Option, ) -> Self { self.substream_upgrade_protocol_override = version; self @@ -84,7 +77,9 @@ where fn into_handler(self, connected: &Connected) -> Self::Handler { NodeHandlerWrapper { - handler: self.handler.into_handler(&connected.peer_id, &connected.endpoint), + handler: self + .handler + .into_handler(&connected.peer_id, &connected.endpoint), negotiating_in: Default::default(), negotiating_out: Default::default(), queued_dial_upgrades: Vec::new(), @@ -105,15 +100,25 @@ where /// The underlying handler. handler: TProtoHandler, /// Futures that upgrade incoming substreams. - negotiating_in: FuturesUnordered, SendWrapper>, - >>, + negotiating_in: FuturesUnordered< + SubstreamUpgrade< + TProtoHandler::InboundOpenInfo, + InboundUpgradeApply< + Substream, + SendWrapper, + >, + >, + >, /// Futures that upgrade outgoing substreams. - negotiating_out: FuturesUnordered, SendWrapper>, - >>, + negotiating_out: FuturesUnordered< + SubstreamUpgrade< + TProtoHandler::OutboundOpenInfo, + OutboundUpgradeApply< + Substream, + SendWrapper, + >, + >, + >, /// For each outbound substream request, how to upgrade it. The first element of the tuple /// is the unique identifier (see `unique_dial_upgrade_id`). queued_dial_upgrades: Vec<(u64, SendWrapper)>, @@ -137,28 +142,43 @@ impl Future for SubstreamUpgrad where Upgrade: Future>> + Unpin, { - type Output = (UserData, Result>); + type Output = ( + UserData, + Result>, + ); fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { match self.timeout.poll_unpin(cx) { - Poll::Ready(Ok(_)) => return Poll::Ready(( - self.user_data.take().expect("Future not to be polled again once ready."), - Err(ProtocolsHandlerUpgrErr::Timeout)), - ), - Poll::Ready(Err(_)) => return Poll::Ready(( - self.user_data.take().expect("Future not to be polled again once ready."), - Err(ProtocolsHandlerUpgrErr::Timer)), - ), - Poll::Pending => {}, + Poll::Ready(Ok(_)) => { + return Poll::Ready(( + self.user_data + .take() + .expect("Future not to be polled again once ready."), + Err(ProtocolsHandlerUpgrErr::Timeout), + )) + } + Poll::Ready(Err(_)) => { + return Poll::Ready(( + self.user_data + .take() + .expect("Future not to be polled again once ready."), + Err(ProtocolsHandlerUpgrErr::Timer), + )) + } + Poll::Pending => {} } match self.upgrade.poll_unpin(cx) { Poll::Ready(Ok(upgrade)) => Poll::Ready(( - self.user_data.take().expect("Future not to be polled again once ready."), + self.user_data + .take() + .expect("Future not to be polled again once ready."), Ok(upgrade), )), Poll::Ready(Err(err)) => Poll::Ready(( - self.user_data.take().expect("Future not to be polled again once ready."), + self.user_data + .take() + .expect("Future not to be polled again once ready."), Err(ProtocolsHandlerUpgrErr::Upgrade(err)), )), Poll::Pending => Poll::Pending, @@ -166,7 +186,6 @@ where } } - /// The options for a planned connection & handler shutdown. /// /// A shutdown is planned anew based on the the return value of @@ -182,7 +201,7 @@ enum Shutdown { /// A shut down is planned as soon as possible. Asap, /// A shut down is planned for when a `Delay` has elapsed. - Later(Delay, Instant) + Later(Delay, Instant), } /// Error generated by the `NodeHandlerWrapper`. @@ -202,20 +221,21 @@ impl From for NodeHandlerWrapperError { impl fmt::Display for NodeHandlerWrapperError where - TErr: fmt::Display + TErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { NodeHandlerWrapperError::Handler(err) => write!(f, "{}", err), - NodeHandlerWrapperError::KeepAliveTimeout => - write!(f, "Connection closed due to expired keep-alive timeout."), + NodeHandlerWrapperError::KeepAliveTimeout => { + write!(f, "Connection closed due to expired keep-alive timeout.") + } } } } impl error::Error for NodeHandlerWrapperError where - TErr: error::Error + 'static + TErr: error::Error + 'static, { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { @@ -272,7 +292,11 @@ where let mut version = upgrade::Version::default(); if let Some(v) = self.substream_upgrade_protocol_override { if v != version { - log::debug!("Substream upgrade protocol override: {:?} -> {:?}", version, v); + log::debug!( + "Substream upgrade protocol override: {:?} -> {:?}", + version, + v + ); version = v; } } @@ -295,19 +319,25 @@ where self.handler.inject_address_change(new_address); } - fn poll(&mut self, cx: &mut Context<'_>) -> Poll< - Result, Self::Error> - > { + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> + { while let Poll::Ready(Some((user_data, res))) = self.negotiating_in.poll_next_unpin(cx) { match res { - Ok(upgrade) => self.handler.inject_fully_negotiated_inbound(upgrade, user_data), + Ok(upgrade) => self + .handler + .inject_fully_negotiated_inbound(upgrade, user_data), Err(err) => self.handler.inject_listen_upgrade_error(user_data, err), } } while let Poll::Ready(Some((user_data, res))) = self.negotiating_out.poll_next_unpin(cx) { match res { - Ok(upgrade) => self.handler.inject_fully_negotiated_outbound(upgrade, user_data), + Ok(upgrade) => self + .handler + .inject_fully_negotiated_outbound(upgrade, user_data), Err(err) => self.handler.inject_dial_upgrade_error(user_data, err), } } @@ -319,14 +349,15 @@ where // Ask the handler whether it wants the connection (and the handler itself) // to be kept alive, which determines the planned shutdown, if any. match (&mut self.shutdown, self.handler.connection_keep_alive()) { - (Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => + (Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => { if *deadline != t { *deadline = t; timer.reset_at(t) - }, + } + } (_, KeepAlive::Until(t)) => self.shutdown = Shutdown::Later(Delay::new_at(t), t), (_, KeepAlive::No) => self.shutdown = Shutdown::Asap, - (_, KeepAlive::Yes) => self.shutdown = Shutdown::None + (_, KeepAlive::Yes) => self.shutdown = Shutdown::None, }; match poll_result { @@ -339,9 +370,9 @@ where self.unique_dial_upgrade_id += 1; let (upgrade, info) = protocol.into_upgrade(); self.queued_dial_upgrades.push((id, SendWrapper(upgrade))); - return Poll::Ready(Ok( - ConnectionHandlerEvent::OutboundSubstreamRequest((id, info, timeout)), - )); + return Poll::Ready(Ok(ConnectionHandlerEvent::OutboundSubstreamRequest(( + id, info, timeout, + )))); } Poll::Ready(ProtocolsHandlerEvent::Close(err)) => return Poll::Ready(Err(err.into())), Poll::Pending => (), @@ -351,12 +382,16 @@ where // As long as we're still negotiating substreams, shutdown is always postponed. if self.negotiating_in.is_empty() && self.negotiating_out.is_empty() { match self.shutdown { - Shutdown::None => {}, - Shutdown::Asap => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)), + Shutdown::None => {} + Shutdown::Asap => { + return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)) + } Shutdown::Later(ref mut delay, _) => match Future::poll(Pin::new(delay), cx) { - Poll::Ready(_) => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)), + Poll::Ready(_) => { + return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)) + } Poll::Pending => {} - } + }, } } diff --git a/swarm/src/protocols_handler/one_shot.rs b/swarm/src/protocols_handler/one_shot.rs index d19dd89d39e..01a2951efc5 100644 --- a/swarm/src/protocols_handler/one_shot.rs +++ b/swarm/src/protocols_handler/one_shot.rs @@ -18,14 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, - SubstreamProtocol + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use smallvec::SmallVec; use std::{error, fmt::Debug, task::Context, task::Poll, time::Duration}; @@ -53,8 +49,7 @@ where config: OneShotHandlerConfig, } -impl - OneShotHandler +impl OneShotHandler where TOutbound: OutboundUpgradeSend, { @@ -102,8 +97,7 @@ where } } -impl Default - for OneShotHandler +impl Default for OneShotHandler where TOutbound: OutboundUpgradeSend, TInbound: InboundUpgradeSend + Default, @@ -111,7 +105,7 @@ where fn default() -> Self { OneShotHandler::new( SubstreamProtocol::new(Default::default(), ()), - OneShotHandlerConfig::default() + OneShotHandlerConfig::default(), ) } } @@ -128,9 +122,7 @@ where { type InEvent = TOutbound; type OutEvent = TEvent; - type Error = ProtocolsHandlerUpgrErr< - ::Error, - >; + type Error = ProtocolsHandlerUpgrErr<::Error>; type InboundProtocol = TInbound; type OutboundProtocol = TOutbound; type OutboundOpenInfo = (); @@ -143,7 +135,7 @@ where fn inject_fully_negotiated_inbound( &mut self, out: ::Output, - (): Self::InboundOpenInfo + (): Self::InboundOpenInfo, ) { // If we're shutting down the connection for inactivity, reset the timeout. if !self.keep_alive.is_yes() { @@ -169,9 +161,7 @@ where fn inject_dial_upgrade_error( &mut self, _info: Self::OutboundOpenInfo, - error: ProtocolsHandlerUpgrErr< - ::Error, - >, + error: ProtocolsHandlerUpgrErr<::Error>, ) { if self.pending_error.is_none() { self.pending_error = Some(error); @@ -186,16 +176,19 @@ where &mut self, _: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { if let Some(err) = self.pending_error.take() { - return Poll::Ready(ProtocolsHandlerEvent::Close(err)) + return Poll::Ready(ProtocolsHandlerEvent::Close(err)); } if !self.events_out.is_empty() { - return Poll::Ready(ProtocolsHandlerEvent::Custom( - self.events_out.remove(0) - )); + return Poll::Ready(ProtocolsHandlerEvent::Custom(self.events_out.remove(0))); } else { self.events_out.shrink_to_fit(); } @@ -204,12 +197,10 @@ where if self.dial_negotiated < self.config.max_dial_negotiated { self.dial_negotiated += 1; let upgrade = self.dial_queue.remove(0); - return Poll::Ready( - ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(upgrade, ()) - .with_timeout(self.config.outbound_substream_timeout) - }, - ); + return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(upgrade, ()) + .with_timeout(self.config.outbound_substream_timeout), + }); } } else { self.dial_queue.shrink_to_fit(); @@ -256,18 +247,19 @@ mod tests { #[test] fn do_not_keep_idle_connection_alive() { let mut handler: OneShotHandler<_, DeniedUpgrade, Void> = OneShotHandler::new( - SubstreamProtocol::new(DeniedUpgrade{}, ()), + SubstreamProtocol::new(DeniedUpgrade {}, ()), Default::default(), ); - block_on(poll_fn(|cx| { - loop { - if let Poll::Pending = handler.poll(cx) { - return Poll::Ready(()) - } + block_on(poll_fn(|cx| loop { + if let Poll::Pending = handler.poll(cx) { + return Poll::Ready(()); } })); - assert!(matches!(handler.connection_keep_alive(), KeepAlive::Until(_))); + assert!(matches!( + handler.connection_keep_alive(), + KeepAlive::Until(_) + )); } } diff --git a/swarm/src/protocols_handler/select.rs b/swarm/src/protocols_handler/select.rs index d8005eef79d..b5891c25d1f 100644 --- a/swarm/src/protocols_handler/select.rs +++ b/swarm/src/protocols_handler/select.rs @@ -18,22 +18,16 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::upgrade::{SendWrapper, InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - IntoProtocolsHandler, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, + IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper}; use libp2p_core::{ - ConnectedPoint, - Multiaddr, - PeerId, either::{EitherError, EitherOutput}, - upgrade::{EitherUpgrade, SelectUpgrade, UpgradeError, NegotiationError, ProtocolError} + upgrade::{EitherUpgrade, NegotiationError, ProtocolError, SelectUpgrade, UpgradeError}, + ConnectedPoint, Multiaddr, PeerId, }; use std::{cmp, task::Context, task::Poll}; @@ -49,10 +43,7 @@ pub struct IntoProtocolsHandlerSelect { impl IntoProtocolsHandlerSelect { /// Builds a `IntoProtocolsHandlerSelect`. pub(crate) fn new(proto1: TProto1, proto2: TProto2) -> Self { - IntoProtocolsHandlerSelect { - proto1, - proto2, - } + IntoProtocolsHandlerSelect { proto1, proto2 } } } @@ -63,7 +54,11 @@ where { type Handler = ProtocolsHandlerSelect; - fn into_handler(self, remote_peer_id: &PeerId, connected_point: &ConnectedPoint) -> Self::Handler { + fn into_handler( + self, + remote_peer_id: &PeerId, + connected_point: &ConnectedPoint, + ) -> Self::Handler { ProtocolsHandlerSelect { proto1: self.proto1.into_handler(remote_peer_id, connected_point), proto2: self.proto2.into_handler(remote_peer_id, connected_point), @@ -71,7 +66,10 @@ where } fn inbound_protocol(&self) -> ::InboundProtocol { - SelectUpgrade::new(SendWrapper(self.proto1.inbound_protocol()), SendWrapper(self.proto2.inbound_protocol())) + SelectUpgrade::new( + SendWrapper(self.proto1.inbound_protocol()), + SendWrapper(self.proto2.inbound_protocol()), + ) } } @@ -87,10 +85,7 @@ pub struct ProtocolsHandlerSelect { impl ProtocolsHandlerSelect { /// Builds a `ProtocolsHandlerSelect`. pub(crate) fn new(proto1: TProto1, proto2: TProto2) -> Self { - ProtocolsHandlerSelect { - proto1, - proto2, - } + ProtocolsHandlerSelect { proto1, proto2 } } } @@ -102,8 +97,14 @@ where type InEvent = EitherOutput; type OutEvent = EitherOutput; type Error = EitherError; - type InboundProtocol = SelectUpgrade::InboundProtocol>, SendWrapper<::InboundProtocol>>; - type OutboundProtocol = EitherUpgrade, SendWrapper>; + type InboundProtocol = SelectUpgrade< + SendWrapper<::InboundProtocol>, + SendWrapper<::InboundProtocol>, + >; + type OutboundProtocol = EitherUpgrade< + SendWrapper, + SendWrapper, + >; type OutboundOpenInfo = EitherOutput; type InboundOpenInfo = (TProto1::InboundOpenInfo, TProto2::InboundOpenInfo); @@ -117,25 +118,39 @@ where SubstreamProtocol::new(choice, (i1, i2)).with_timeout(timeout) } - fn inject_fully_negotiated_outbound(&mut self, protocol: ::Output, endpoint: Self::OutboundOpenInfo) { + fn inject_fully_negotiated_outbound( + &mut self, + protocol: ::Output, + endpoint: Self::OutboundOpenInfo, + ) { match (protocol, endpoint) { - (EitherOutput::First(protocol), EitherOutput::First(info)) => - self.proto1.inject_fully_negotiated_outbound(protocol, info), - (EitherOutput::Second(protocol), EitherOutput::Second(info)) => - self.proto2.inject_fully_negotiated_outbound(protocol, info), - (EitherOutput::First(_), EitherOutput::Second(_)) => - panic!("wrong API usage: the protocol doesn't match the upgrade info"), - (EitherOutput::Second(_), EitherOutput::First(_)) => + (EitherOutput::First(protocol), EitherOutput::First(info)) => { + self.proto1.inject_fully_negotiated_outbound(protocol, info) + } + (EitherOutput::Second(protocol), EitherOutput::Second(info)) => { + self.proto2.inject_fully_negotiated_outbound(protocol, info) + } + (EitherOutput::First(_), EitherOutput::Second(_)) => { panic!("wrong API usage: the protocol doesn't match the upgrade info") + } + (EitherOutput::Second(_), EitherOutput::First(_)) => { + panic!("wrong API usage: the protocol doesn't match the upgrade info") + } } } - fn inject_fully_negotiated_inbound(&mut self, protocol: ::Output, (i1, i2): Self::InboundOpenInfo) { + fn inject_fully_negotiated_inbound( + &mut self, + protocol: ::Output, + (i1, i2): Self::InboundOpenInfo, + ) { match protocol { - EitherOutput::First(protocol) => - self.proto1.inject_fully_negotiated_inbound(protocol, i1), - EitherOutput::Second(protocol) => + EitherOutput::First(protocol) => { + self.proto1.inject_fully_negotiated_inbound(protocol, i1) + } + EitherOutput::Second(protocol) => { self.proto2.inject_fully_negotiated_inbound(protocol, i2) + } } } @@ -151,60 +166,108 @@ where self.proto2.inject_address_change(new_address) } - fn inject_dial_upgrade_error(&mut self, info: Self::OutboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_dial_upgrade_error( + &mut self, + info: Self::OutboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { match (info, error) { - (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timer) => { - self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer) - }, - (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timeout) => { - self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout) - }, - (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) => { - self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) - }, - (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(err)))) => { - self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err))) - }, - (EitherOutput::First(_), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(_)))) => { + (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timer) => self + .proto1 + .inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer), + (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timeout) => self + .proto1 + .inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout), + ( + EitherOutput::First(info), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)), + ) => self.proto1.inject_dial_upgrade_error( + info, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)), + ), + ( + EitherOutput::First(info), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(err))), + ) => self.proto1.inject_dial_upgrade_error( + info, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err)), + ), + ( + EitherOutput::First(_), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(_))), + ) => { panic!("Wrong API usage; the upgrade error doesn't match the outbound open info"); - }, - (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timeout) => { - self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout) - }, - (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timer) => { - self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer) - }, - (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) => { - self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) - }, - (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(err)))) => { - self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err))) - }, - (EitherOutput::Second(_), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(_)))) => { + } + (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timeout) => self + .proto2 + .inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout), + (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timer) => self + .proto2 + .inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer), + ( + EitherOutput::Second(info), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)), + ) => self.proto2.inject_dial_upgrade_error( + info, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)), + ), + ( + EitherOutput::Second(info), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(err))), + ) => self.proto2.inject_dial_upgrade_error( + info, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err)), + ), + ( + EitherOutput::Second(_), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(_))), + ) => { panic!("Wrong API usage; the upgrade error doesn't match the outbound open info"); - }, + } } } - fn inject_listen_upgrade_error(&mut self, (i1, i2): Self::InboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_listen_upgrade_error( + &mut self, + (i1, i2): Self::InboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { match error { ProtocolsHandlerUpgrErr::Timer => { - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timer); - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timer) + self.proto1 + .inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timer); + self.proto2 + .inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timer) } ProtocolsHandlerUpgrErr::Timeout => { - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timeout); - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timeout) + self.proto1 + .inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timeout); + self.proto2 + .inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timeout) } ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed))); - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed))); + self.proto1.inject_listen_upgrade_error( + i1, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::Failed, + )), + ); + self.proto2.inject_listen_upgrade_error( + i2, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::Failed, + )), + ); } - ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::ProtocolError(e))) => { + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::ProtocolError(e), + )) => { let (e1, e2); match e { ProtocolError::IoError(e) => { - e1 = NegotiationError::ProtocolError(ProtocolError::IoError(e.kind().into())); + e1 = NegotiationError::ProtocolError(ProtocolError::IoError( + e.kind().into(), + )); e2 = NegotiationError::ProtocolError(ProtocolError::IoError(e)) } ProtocolError::InvalidMessage => { @@ -220,55 +283,80 @@ where e2 = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols) } } - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e1))); - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e2))) + self.proto1.inject_listen_upgrade_error( + i1, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e1)), + ); + self.proto2.inject_listen_upgrade_error( + i2, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e2)), + ) } ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(e))) => { - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e))) + self.proto1.inject_listen_upgrade_error( + i1, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)), + ) } ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(e))) => { - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e))) + self.proto2.inject_listen_upgrade_error( + i2, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)), + ) } } } fn connection_keep_alive(&self) -> KeepAlive { - cmp::max(self.proto1.connection_keep_alive(), self.proto2.connection_keep_alive()) + cmp::max( + self.proto1.connection_keep_alive(), + self.proto2.connection_keep_alive(), + ) } - fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, + > { match self.proto1.poll(cx) { Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => { return Poll::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::First(event))); - }, + } Poll::Ready(ProtocolsHandlerEvent::Close(event)) => { return Poll::Ready(ProtocolsHandlerEvent::Close(EitherError::A(event))); - }, + } Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }) => { return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol: protocol .map_upgrade(|u| EitherUpgrade::A(SendWrapper(u))) - .map_info(EitherOutput::First) + .map_info(EitherOutput::First), }); - }, - Poll::Pending => () + } + Poll::Pending => (), }; match self.proto2.poll(cx) { Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => { return Poll::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::Second(event))); - }, + } Poll::Ready(ProtocolsHandlerEvent::Close(event)) => { return Poll::Ready(ProtocolsHandlerEvent::Close(EitherError::B(event))); - }, + } Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }) => { return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol: protocol .map_upgrade(|u| EitherUpgrade::B(SendWrapper(u))) - .map_info(EitherOutput::Second) + .map_info(EitherOutput::Second), }); - }, - Poll::Pending => () + } + Poll::Pending => (), }; Poll::Pending diff --git a/swarm/src/registry.rs b/swarm/src/registry.rs index 310639296d8..5819ecf1e4e 100644 --- a/swarm/src/registry.rs +++ b/swarm/src/registry.rs @@ -20,8 +20,8 @@ use libp2p_core::Multiaddr; use smallvec::SmallVec; -use std::{collections::VecDeque, cmp::Ordering, num::NonZeroUsize}; use std::ops::{Add, Sub}; +use std::{cmp::Ordering, collections::VecDeque, num::NonZeroUsize}; /// A ranked collection of [`Multiaddr`] values. /// @@ -77,9 +77,7 @@ struct Report { impl AddressRecord { fn new(addr: Multiaddr, score: AddressScore) -> Self { - AddressRecord { - addr, score, - } + AddressRecord { addr, score } } } @@ -117,14 +115,10 @@ impl Ord for AddressScore { fn cmp(&self, other: &AddressScore) -> Ordering { // Semantics of cardinal numbers with a single infinite cardinal. match (self, other) { - (AddressScore::Infinite, AddressScore::Infinite) => - Ordering::Equal, - (AddressScore::Infinite, AddressScore::Finite(_)) => - Ordering::Greater, - (AddressScore::Finite(_), AddressScore::Infinite) => - Ordering::Less, - (AddressScore::Finite(a), AddressScore::Finite(b)) => - a.cmp(b), + (AddressScore::Infinite, AddressScore::Infinite) => Ordering::Equal, + (AddressScore::Infinite, AddressScore::Finite(_)) => Ordering::Greater, + (AddressScore::Finite(_), AddressScore::Infinite) => Ordering::Less, + (AddressScore::Finite(a), AddressScore::Finite(b)) => a.cmp(b), } } } @@ -135,14 +129,12 @@ impl Add for AddressScore { fn add(self, rhs: AddressScore) -> Self::Output { // Semantics of cardinal numbers with a single infinite cardinal. match (self, rhs) { - (AddressScore::Infinite, AddressScore::Infinite) => - AddressScore::Infinite, - (AddressScore::Infinite, AddressScore::Finite(_)) => - AddressScore::Infinite, - (AddressScore::Finite(_), AddressScore::Infinite) => - AddressScore::Infinite, - (AddressScore::Finite(a), AddressScore::Finite(b)) => + (AddressScore::Infinite, AddressScore::Infinite) => AddressScore::Infinite, + (AddressScore::Infinite, AddressScore::Finite(_)) => AddressScore::Infinite, + (AddressScore::Finite(_), AddressScore::Infinite) => AddressScore::Infinite, + (AddressScore::Finite(a), AddressScore::Finite(b)) => { AddressScore::Finite(a.saturating_add(b)) + } } } } @@ -154,7 +146,7 @@ impl Sub for AddressScore { // Semantics of cardinal numbers with a single infinite cardinal. match self { AddressScore::Infinite => AddressScore::Infinite, - AddressScore::Finite(score) => AddressScore::Finite(score.saturating_sub(rhs)) + AddressScore::Finite(score) => AddressScore::Finite(score.saturating_sub(rhs)), } } } @@ -168,8 +160,12 @@ impl Default for Addresses { /// The result of adding an address to an ordered list of /// addresses with associated scores. pub enum AddAddressResult { - Inserted { expired: SmallVec<[AddressRecord; 8]> }, - Updated { expired: SmallVec<[AddressRecord; 8]> }, + Inserted { + expired: SmallVec<[AddressRecord; 8]>, + }, + Updated { + expired: SmallVec<[AddressRecord; 8]>, + }, } impl Addresses { @@ -207,7 +203,12 @@ impl Addresses { // Remove addresses that have a score of 0. let mut expired = SmallVec::new(); - while self.registry.last().map(|e| e.score.is_zero()).unwrap_or(false) { + while self + .registry + .last() + .map(|e| e.score.is_zero()) + .unwrap_or(false) + { if let Some(addr) = self.registry.pop() { expired.push(addr); } @@ -215,7 +216,10 @@ impl Addresses { // If the address score is finite, remember this report. if let AddressScore::Finite(score) = score { - self.reports.push_back(Report { addr: addr.clone(), score }); + self.reports.push_back(Report { + addr: addr.clone(), + score, + }); } // If the address is already in the collection, increase its score. @@ -223,7 +227,7 @@ impl Addresses { if r.addr == addr { r.score = r.score + score; isort(&mut self.registry); - return AddAddressResult::Updated { expired } + return AddAddressResult::Updated { expired }; } } @@ -249,14 +253,19 @@ impl Addresses { /// /// The iteration is ordered by descending score. pub fn iter(&self) -> AddressIter<'_> { - AddressIter { items: &self.registry, offset: 0 } + AddressIter { + items: &self.registry, + offset: 0, + } } /// Return an iterator over all [`Multiaddr`] values. /// /// The iteration is ordered by descending score. pub fn into_iter(self) -> AddressIntoIter { - AddressIntoIter { items: self.registry } + AddressIntoIter { + items: self.registry, + } } } @@ -264,7 +273,7 @@ impl Addresses { #[derive(Clone)] pub struct AddressIter<'a> { items: &'a [AddressRecord], - offset: usize + offset: usize, } impl<'a> Iterator for AddressIter<'a> { @@ -272,7 +281,7 @@ impl<'a> Iterator for AddressIter<'a> { fn next(&mut self) -> Option { if self.offset == self.items.len() { - return None + return None; } let item = &self.items[self.offset]; self.offset += 1; @@ -314,10 +323,10 @@ impl ExactSizeIterator for AddressIntoIter {} // Reverse insertion sort. fn isort(xs: &mut [AddressRecord]) { - for i in 1 .. xs.len() { - for j in (1 ..= i).rev() { + for i in 1..xs.len() { + for j in (1..=i).rev() { if xs[j].score <= xs[j - 1].score { - break + break; } xs.swap(j, j - 1) } @@ -326,15 +335,16 @@ fn isort(xs: &mut [AddressRecord]) { #[cfg(test)] mod tests { + use super::*; use libp2p_core::multiaddr::{Multiaddr, Protocol}; use quickcheck::*; use rand::Rng; - use std::num::{NonZeroUsize, NonZeroU8}; - use super::*; + use std::num::{NonZeroU8, NonZeroUsize}; impl Arbitrary for AddressScore { fn arbitrary(g: &mut G) -> AddressScore { - if g.gen_range(0, 10) == 0 { // ~10% "Infinitely" scored addresses + if g.gen_range(0, 10) == 0 { + // ~10% "Infinitely" scored addresses AddressScore::Infinite } else { AddressScore::Finite(g.gen()) @@ -353,13 +363,14 @@ mod tests { #[test] fn isort_sorts() { fn property(xs: Vec) { - let mut xs = xs.into_iter() + let mut xs = xs + .into_iter() .map(|score| AddressRecord::new(Multiaddr::empty(), score)) .collect::>(); isort(&mut xs); - for i in 1 .. xs.len() { + for i in 1..xs.len() { assert!(xs[i - 1].score >= xs[i].score) } } @@ -371,7 +382,7 @@ mod tests { fn score_retention() { fn prop(first: AddressRecord, other: AddressRecord) -> TestResult { if first.addr == other.addr { - return TestResult::discard() + return TestResult::discard(); } let mut addresses = Addresses::default(); @@ -383,7 +394,7 @@ mod tests { // Add another address so often that the initial report of // the first address may be purged and, since it was the // only report, the address removed. - for _ in 0 .. addresses.limit.get() + 1 { + for _ in 0..addresses.limit.get() + 1 { addresses.add(other.addr.clone(), other.score); } @@ -398,7 +409,7 @@ mod tests { TestResult::passed() } - quickcheck(prop as fn(_,_) -> _); + quickcheck(prop as fn(_, _) -> _); } #[test] @@ -412,16 +423,22 @@ mod tests { } // Count the finitely scored addresses. - let num_finite = addresses.iter().filter(|r| match r { - AddressRecord { score: AddressScore::Finite(_), .. } => true, - _ => false, - }).count(); + let num_finite = addresses + .iter() + .filter(|r| match r { + AddressRecord { + score: AddressScore::Finite(_), + .. + } => true, + _ => false, + }) + .count(); // Check against the limit. assert!(num_finite <= limit.get() as usize); } - quickcheck(prop as fn(_,_)); + quickcheck(prop as fn(_, _)); } #[test] @@ -438,16 +455,16 @@ mod tests { // Check that each address in the registry has the expected score. for r in &addresses.registry { - let expected_score = records.iter().fold( - None::, |sum, rec| - if &rec.addr == &r.addr { - sum.map_or(Some(rec.score), |s| Some(s + rec.score)) - } else { - sum - }); + let expected_score = records.iter().fold(None::, |sum, rec| { + if &rec.addr == &r.addr { + sum.map_or(Some(rec.score), |s| Some(s + rec.score)) + } else { + sum + } + }); if Some(r.score) != expected_score { - return false + return false; } } diff --git a/swarm/src/test.rs b/swarm/src/test.rs index 4ae647d38db..5cb05d7baf3 100644 --- a/swarm/src/test.rs +++ b/swarm/src/test.rs @@ -19,17 +19,13 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - NetworkBehaviour, - NetworkBehaviourAction, + IntoProtocolsHandler, NetworkBehaviour, NetworkBehaviourAction, PollParameters, ProtocolsHandler, - IntoProtocolsHandler, - PollParameters }; use libp2p_core::{ - ConnectedPoint, - PeerId, connection::{ConnectionId, ListenerId}, multiaddr::Multiaddr, + ConnectedPoint, PeerId, }; use std::collections::HashMap; use std::task::{Context, Poll}; @@ -54,7 +50,7 @@ where impl MockBehaviour where - THandler: ProtocolsHandler + THandler: ProtocolsHandler, { pub fn new(handler_proto: THandler) -> Self { MockBehaviour { @@ -82,12 +78,13 @@ where self.addresses.get(p).map_or(Vec::new(), |v| v.clone()) } - fn inject_event(&mut self, _: PeerId, _: ConnectionId, _: THandler::OutEvent) { - } + fn inject_event(&mut self, _: PeerId, _: ConnectionId, _: THandler::OutEvent) {} - fn poll(&mut self, _: &mut Context, _: &mut impl PollParameters) -> - Poll> - { + fn poll( + &mut self, + _: &mut Context, + _: &mut impl PollParameters, + ) -> Poll> { self.next_action.take().map_or(Poll::Pending, Poll::Ready) } } @@ -106,7 +103,11 @@ where pub inject_disconnected: Vec, pub inject_connection_established: Vec<(PeerId, ConnectionId, ConnectedPoint)>, pub inject_connection_closed: Vec<(PeerId, ConnectionId, ConnectedPoint)>, - pub inject_event: Vec<(PeerId, ConnectionId, <::Handler as ProtocolsHandler>::OutEvent)>, + pub inject_event: Vec<( + PeerId, + ConnectionId, + <::Handler as ProtocolsHandler>::OutEvent, + )>, pub inject_addr_reach_failure: Vec<(Option, Multiaddr)>, pub inject_dial_failure: Vec, pub inject_new_listener: Vec, @@ -121,7 +122,7 @@ where impl CallTraceBehaviour where - TInner: NetworkBehaviour + TInner: NetworkBehaviour, { pub fn new(inner: TInner) -> Self { Self { @@ -162,13 +163,16 @@ where self.poll = 0; } - pub fn inner(&mut self) -> &mut TInner { &mut self.inner } + pub fn inner(&mut self) -> &mut TInner { + &mut self.inner + } } impl NetworkBehaviour for CallTraceBehaviour where TInner: NetworkBehaviour, - <::Handler as ProtocolsHandler>::OutEvent: Clone, + <::Handler as ProtocolsHandler>::OutEvent: + Clone, { type ProtocolsHandler = TInner::ProtocolsHandler; type OutEvent = TInner::OutEvent; @@ -188,7 +192,8 @@ where } fn inject_connection_established(&mut self, p: &PeerId, c: &ConnectionId, e: &ConnectedPoint) { - self.inject_connection_established.push((p.clone(), c.clone(), e.clone())); + self.inject_connection_established + .push((p.clone(), c.clone(), e.clone())); self.inner.inject_connection_established(p, c, e); } @@ -198,16 +203,27 @@ where } fn inject_connection_closed(&mut self, p: &PeerId, c: &ConnectionId, e: &ConnectedPoint) { - self.inject_connection_closed.push((p.clone(), c.clone(), e.clone())); + self.inject_connection_closed + .push((p.clone(), c.clone(), e.clone())); self.inner.inject_connection_closed(p, c, e); } - fn inject_event(&mut self, p: PeerId, c: ConnectionId, e: <::Handler as ProtocolsHandler>::OutEvent) { + fn inject_event( + &mut self, + p: PeerId, + c: ConnectionId, + e: <::Handler as ProtocolsHandler>::OutEvent, + ) { self.inject_event.push((p.clone(), c.clone(), e.clone())); self.inner.inject_event(p, c, e); } - fn inject_addr_reach_failure(&mut self, p: Option<&PeerId>, a: &Multiaddr, e: &dyn std::error::Error) { + fn inject_addr_reach_failure( + &mut self, + p: Option<&PeerId>, + a: &Multiaddr, + e: &dyn std::error::Error, + ) { self.inject_addr_reach_failure.push((p.cloned(), a.clone())); self.inner.inject_addr_reach_failure(p, a, e); } diff --git a/swarm/src/toggle.rs b/swarm/src/toggle.rs index d986f00fb01..5a86a4824ed 100644 --- a/swarm/src/toggle.rs +++ b/swarm/src/toggle.rs @@ -18,24 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{NetworkBehaviour, NetworkBehaviourAction, NetworkBehaviourEventProcess, PollParameters}; -use crate::upgrade::{SendWrapper, InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, - IntoProtocolsHandler + IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, +}; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper}; +use crate::{ + NetworkBehaviour, NetworkBehaviourAction, NetworkBehaviourEventProcess, PollParameters, }; use either::Either; use libp2p_core::{ - ConnectedPoint, - PeerId, - Multiaddr, connection::{ConnectionId, ListenerId}, either::{EitherError, EitherOutput}, - upgrade::{DeniedUpgrade, EitherUpgrade} + upgrade::{DeniedUpgrade, EitherUpgrade}, + ConnectedPoint, Multiaddr, PeerId, }; use std::{error, task::Context, task::Poll}; @@ -71,19 +67,22 @@ impl From> for Toggle { impl NetworkBehaviour for Toggle where - TBehaviour: NetworkBehaviour + TBehaviour: NetworkBehaviour, { type ProtocolsHandler = ToggleIntoProtoHandler; type OutEvent = TBehaviour::OutEvent; fn new_handler(&mut self) -> Self::ProtocolsHandler { ToggleIntoProtoHandler { - inner: self.inner.as_mut().map(|i| i.new_handler()) + inner: self.inner.as_mut().map(|i| i.new_handler()), } } fn addresses_of_peer(&mut self, peer_id: &PeerId) -> Vec { - self.inner.as_mut().map(|b| b.addresses_of_peer(peer_id)).unwrap_or_else(Vec::new) + self.inner + .as_mut() + .map(|b| b.addresses_of_peer(peer_id)) + .unwrap_or_else(Vec::new) } fn inject_connected(&mut self, peer_id: &PeerId) { @@ -98,19 +97,35 @@ where } } - fn inject_connection_established(&mut self, peer_id: &PeerId, connection: &ConnectionId, endpoint: &ConnectedPoint) { + fn inject_connection_established( + &mut self, + peer_id: &PeerId, + connection: &ConnectionId, + endpoint: &ConnectedPoint, + ) { if let Some(inner) = self.inner.as_mut() { inner.inject_connection_established(peer_id, connection, endpoint) } } - fn inject_connection_closed(&mut self, peer_id: &PeerId, connection: &ConnectionId, endpoint: &ConnectedPoint) { + fn inject_connection_closed( + &mut self, + peer_id: &PeerId, + connection: &ConnectionId, + endpoint: &ConnectedPoint, + ) { if let Some(inner) = self.inner.as_mut() { inner.inject_connection_closed(peer_id, connection, endpoint) } } - fn inject_address_change(&mut self, peer_id: &PeerId, connection: &ConnectionId, old: &ConnectedPoint, new: &ConnectedPoint) { + fn inject_address_change( + &mut self, + peer_id: &PeerId, + connection: &ConnectionId, + old: &ConnectedPoint, + new: &ConnectedPoint, + ) { if let Some(inner) = self.inner.as_mut() { inner.inject_address_change(peer_id, connection, old, new) } @@ -120,14 +135,19 @@ where &mut self, peer_id: PeerId, connection: ConnectionId, - event: <::Handler as ProtocolsHandler>::OutEvent + event: <::Handler as ProtocolsHandler>::OutEvent, ) { if let Some(inner) = self.inner.as_mut() { inner.inject_event(peer_id, connection, event); } } - fn inject_addr_reach_failure(&mut self, peer_id: Option<&PeerId>, addr: &Multiaddr, error: &dyn error::Error) { + fn inject_addr_reach_failure( + &mut self, + peer_id: Option<&PeerId>, + addr: &Multiaddr, + error: &dyn error::Error, + ) { if let Some(inner) = self.inner.as_mut() { inner.inject_addr_reach_failure(peer_id, addr, error) } @@ -194,7 +214,7 @@ where impl NetworkBehaviourEventProcess for Toggle where - TBehaviour: NetworkBehaviourEventProcess + TBehaviour: NetworkBehaviourEventProcess, { fn inject_event(&mut self, event: TEvent) { if let Some(inner) = self.inner.as_mut() { @@ -210,13 +230,19 @@ pub struct ToggleIntoProtoHandler { impl IntoProtocolsHandler for ToggleIntoProtoHandler where - TInner: IntoProtocolsHandler + TInner: IntoProtocolsHandler, { type Handler = ToggleProtoHandler; - fn into_handler(self, remote_peer_id: &PeerId, connected_point: &ConnectedPoint) -> Self::Handler { + fn into_handler( + self, + remote_peer_id: &PeerId, + connected_point: &ConnectedPoint, + ) -> Self::Handler { ToggleProtoHandler { - inner: self.inner.map(|h| h.into_handler(remote_peer_id, connected_point)) + inner: self + .inner + .map(|h| h.into_handler(remote_peer_id, connected_point)), } } @@ -241,25 +267,30 @@ where type InEvent = TInner::InEvent; type OutEvent = TInner::OutEvent; type Error = TInner::Error; - type InboundProtocol = EitherUpgrade, SendWrapper>; + type InboundProtocol = + EitherUpgrade, SendWrapper>; type OutboundProtocol = TInner::OutboundProtocol; type OutboundOpenInfo = TInner::OutboundOpenInfo; type InboundOpenInfo = Either; fn listen_protocol(&self) -> SubstreamProtocol { if let Some(inner) = self.inner.as_ref() { - inner.listen_protocol() + inner + .listen_protocol() .map_upgrade(|u| EitherUpgrade::A(SendWrapper(u))) .map_info(Either::Left) } else { - SubstreamProtocol::new(EitherUpgrade::B(SendWrapper(DeniedUpgrade)), Either::Right(())) + SubstreamProtocol::new( + EitherUpgrade::B(SendWrapper(DeniedUpgrade)), + Either::Right(()), + ) } } fn inject_fully_negotiated_inbound( &mut self, out: ::Output, - info: Self::InboundOpenInfo + info: Self::InboundOpenInfo, ) { let out = match out { EitherOutput::First(out) => out, @@ -267,7 +298,8 @@ where }; if let Either::Left(info) = info { - self.inner.as_mut() + self.inner + .as_mut() .expect("Can't receive an inbound substream if disabled; QED") .inject_fully_negotiated_inbound(out, info) } else { @@ -278,14 +310,18 @@ where fn inject_fully_negotiated_outbound( &mut self, out: ::Output, - info: Self::OutboundOpenInfo + info: Self::OutboundOpenInfo, ) { - self.inner.as_mut().expect("Can't receive an outbound substream if disabled; QED") + self.inner + .as_mut() + .expect("Can't receive an outbound substream if disabled; QED") .inject_fully_negotiated_outbound(out, info) } fn inject_event(&mut self, event: Self::InEvent) { - self.inner.as_mut().expect("Can't receive events if disabled; QED") + self.inner + .as_mut() + .expect("Can't receive events if disabled; QED") .inject_event(event) } @@ -295,12 +331,22 @@ where } } - fn inject_dial_upgrade_error(&mut self, info: Self::OutboundOpenInfo, err: ProtocolsHandlerUpgrErr<::Error>) { - self.inner.as_mut().expect("Can't receive an outbound substream if disabled; QED") + fn inject_dial_upgrade_error( + &mut self, + info: Self::OutboundOpenInfo, + err: ProtocolsHandlerUpgrErr<::Error>, + ) { + self.inner + .as_mut() + .expect("Can't receive an outbound substream if disabled; QED") .inject_dial_upgrade_error(info, err) } - fn inject_listen_upgrade_error(&mut self, info: Self::InboundOpenInfo, err: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_listen_upgrade_error( + &mut self, + info: Self::InboundOpenInfo, + err: ProtocolsHandlerUpgrErr<::Error>, + ) { let (inner, info) = match (self.inner.as_mut(), info) { (Some(inner), Either::Left(info)) => (inner, info), // Ignore listen upgrade errors in disabled state. @@ -313,24 +359,26 @@ where "Unexpected `Either::Left` inbound info through \ `inject_listen_upgrade_error` in disabled state.", ), - }; let err = match err { ProtocolsHandlerUpgrErr::Timeout => ProtocolsHandlerUpgrErr::Timeout, ProtocolsHandlerUpgrErr::Timer => ProtocolsHandlerUpgrErr::Timer, - ProtocolsHandlerUpgrErr::Upgrade(err) => + ProtocolsHandlerUpgrErr::Upgrade(err) => { ProtocolsHandlerUpgrErr::Upgrade(err.map_err(|err| match err { EitherError::A(e) => e, - EitherError::B(v) => void::unreachable(v) + EitherError::B(v) => void::unreachable(v), })) + } }; inner.inject_listen_upgrade_error(info, err) } fn connection_keep_alive(&self) -> KeepAlive { - self.inner.as_ref().map(|h| h.connection_keep_alive()) + self.inner + .as_ref() + .map(|h| h.connection_keep_alive()) .unwrap_or(KeepAlive::No) } @@ -338,7 +386,12 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { if let Some(inner) = self.inner.as_mut() { inner.poll(cx) @@ -369,9 +422,7 @@ mod tests { /// [`ToggleProtoHandler`] should ignore the error in both of these cases. #[test] fn ignore_listen_upgrade_error_when_disabled() { - let mut handler = ToggleProtoHandler:: { - inner: None, - }; + let mut handler = ToggleProtoHandler:: { inner: None }; handler.inject_listen_upgrade_error(Either::Right(()), ProtocolsHandlerUpgrErr::Timeout); } diff --git a/transports/deflate/src/lib.rs b/transports/deflate/src/lib.rs index d93e6ed2e39..698b6cab6a9 100644 --- a/transports/deflate/src/lib.rs +++ b/transports/deflate/src/lib.rs @@ -105,11 +105,12 @@ impl DeflateOutput { /// Tries to write the content of `self.write_out` to `self.inner`. /// Returns `Ready(Ok(()))` if `self.write_out` is empty. fn flush_write_out(&mut self, cx: &mut Context<'_>) -> Poll> - where S: AsyncWrite + Unpin + where + S: AsyncWrite + Unpin, { loop { if self.write_out.is_empty() { - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } match AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &self.write_out) { @@ -123,9 +124,14 @@ impl DeflateOutput { } impl AsyncRead for DeflateOutput - where S: AsyncRead + Unpin +where + S: AsyncRead + Unpin, { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { // We use a `this` variable because the compiler doesn't allow multiple mutable borrows // across a `Deref`. let this = &mut *self; @@ -133,31 +139,38 @@ impl AsyncRead for DeflateOutput loop { // Read from `self.inner` into `self.read_interm` if necessary. if this.read_interm.is_empty() && !this.inner_read_eof { - this.read_interm.resize(this.read_interm.capacity() + 256, 0); + this.read_interm + .resize(this.read_interm.capacity() + 256, 0); match AsyncRead::poll_read(Pin::new(&mut this.inner), cx, &mut this.read_interm) { Poll::Ready(Ok(0)) => { this.inner_read_eof = true; this.read_interm.clear(); } - Poll::Ready(Ok(n)) => { - this.read_interm.truncate(n) - }, + Poll::Ready(Ok(n)) => this.read_interm.truncate(n), Poll::Ready(Err(err)) => { this.read_interm.clear(); - return Poll::Ready(Err(err)) - }, + return Poll::Ready(Err(err)); + } Poll::Pending => { this.read_interm.clear(); - return Poll::Pending - }, + return Poll::Pending; + } } } debug_assert!(!this.read_interm.is_empty() || this.inner_read_eof); let before_out = this.decompress.total_out(); let before_in = this.decompress.total_in(); - let ret = this.decompress.decompress(&this.read_interm, buf, if this.inner_read_eof { flate2::FlushDecompress::Finish } else { flate2::FlushDecompress::None })?; + let ret = this.decompress.decompress( + &this.read_interm, + buf, + if this.inner_read_eof { + flate2::FlushDecompress::Finish + } else { + flate2::FlushDecompress::None + }, + )?; // Remove from `self.read_interm` the bytes consumed by the decompressor. let consumed = (this.decompress.total_in() - before_in) as usize; @@ -165,18 +178,21 @@ impl AsyncRead for DeflateOutput let read = (this.decompress.total_out() - before_out) as usize; if read != 0 || ret == flate2::Status::StreamEnd { - return Poll::Ready(Ok(read)) + return Poll::Ready(Ok(read)); } } } } impl AsyncWrite for DeflateOutput - where S: AsyncWrite + Unpin +where + S: AsyncWrite + Unpin, { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) - -> Poll> - { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { // We use a `this` variable because the compiler doesn't allow multiple mutable borrows // across a `Deref`. let this = &mut *self; @@ -195,8 +211,12 @@ impl AsyncWrite for DeflateOutput // Instead, we invoke the compressor in a loop until it accepts some of our data. loop { let before_in = this.compress.total_in(); - this.write_out.reserve(256); // compress_vec uses the Vec's capacity - let ret = this.compress.compress_vec(buf, &mut this.write_out, flate2::FlushCompress::None)?; + this.write_out.reserve(256); // compress_vec uses the Vec's capacity + let ret = this.compress.compress_vec( + buf, + &mut this.write_out, + flate2::FlushCompress::None, + )?; let written = (this.compress.total_in() - before_in) as usize; if written != 0 || ret == flate2::Status::StreamEnd { @@ -211,15 +231,17 @@ impl AsyncWrite for DeflateOutput let this = &mut *self; ready!(this.flush_write_out(cx))?; - this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Sync)?; + this.compress + .compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Sync)?; loop { ready!(this.flush_write_out(cx))?; debug_assert!(this.write_out.is_empty()); // We ask the compressor to flush everything into `self.write_out`. - this.write_out.reserve(256); // compress_vec uses the Vec's capacity - this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::None)?; + this.write_out.reserve(256); // compress_vec uses the Vec's capacity + this.compress + .compress_vec(&[], &mut this.write_out, flate2::FlushCompress::None)?; if this.write_out.is_empty() { break; } @@ -238,8 +260,9 @@ impl AsyncWrite for DeflateOutput // We ask the compressor to flush everything into `self.write_out`. debug_assert!(this.write_out.is_empty()); - this.write_out.reserve(256); // compress_vec uses the Vec's capacity - this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Finish)?; + this.write_out.reserve(256); // compress_vec uses the Vec's capacity + this.compress + .compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Finish)?; if this.write_out.is_empty() { break; } diff --git a/transports/deflate/tests/test.rs b/transports/deflate/tests/test.rs index 896fb491349..6027c4f4afb 100644 --- a/transports/deflate/tests/test.rs +++ b/transports/deflate/tests/test.rs @@ -28,7 +28,7 @@ use quickcheck::{QuickCheck, RngCore, TestResult}; fn deflate() { fn prop(message: Vec) -> TestResult { if message.is_empty() { - return TestResult::discard() + return TestResult::discard(); } async_std::task::block_on(run(message)); TestResult::passed() @@ -44,16 +44,24 @@ fn lot_of_data() { } async fn run(message1: Vec) { - let transport = TcpConfig::new() - .and_then(|conn, endpoint| { - upgrade::apply(conn, DeflateConfig::default(), endpoint, upgrade::Version::V1) - }); + let transport = TcpConfig::new().and_then(|conn, endpoint| { + upgrade::apply( + conn, + DeflateConfig::default(), + endpoint, + upgrade::Version::V1, + ) + }); - let mut listener = transport.clone() + let mut listener = transport + .clone() .listen_on("/ip4/0.0.0.0/tcp/0".parse().expect("multiaddr")) .expect("listener"); - let listen_addr = listener.by_ref().next().await + let listen_addr = listener + .by_ref() + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -82,7 +90,11 @@ async fn run(message1: Vec) { conn.close().await.expect("close") }); - let mut conn = transport.dial(listen_addr).expect("dialer").await.expect("connection"); + let mut conn = transport + .dial(listen_addr) + .expect("dialer") + .await + .expect("connection"); conn.write_all(&message1).await.expect("write_all"); conn.close().await.expect("close"); diff --git a/transports/dns/src/lib.rs b/transports/dns/src/lib.rs index 499f33e8e5e..6174c1e362c 100644 --- a/transports/dns/src/lib.rs +++ b/transports/dns/src/lib.rs @@ -54,27 +54,23 @@ //! //![trust-dns-resolver]: https://docs.rs/trust-dns-resolver/latest/trust_dns_resolver/#dns-over-tls-and-dns-over-https -use futures::{prelude::*, future::BoxFuture}; +#[cfg(feature = "async-std")] +use async_std_resolver::{AsyncStdConnection, AsyncStdConnectionProvider}; +use futures::{future::BoxFuture, prelude::*}; use libp2p_core::{ + multiaddr::{Multiaddr, Protocol}, + transport::{ListenerEvent, TransportError}, Transport, - multiaddr::{Protocol, Multiaddr}, - transport::{TransportError, ListenerEvent} }; use smallvec::SmallVec; -use std::{convert::TryFrom, error, fmt, iter, net::IpAddr, str}; #[cfg(any(feature = "async-std", feature = "tokio"))] use std::io; +use std::{convert::TryFrom, error, fmt, iter, net::IpAddr, str}; #[cfg(any(feature = "async-std", feature = "tokio"))] use trust_dns_resolver::system_conf; -use trust_dns_resolver::{ - AsyncResolver, - ConnectionProvider, - proto::xfer::dns_handle::DnsHandle, -}; +use trust_dns_resolver::{proto::xfer::dns_handle::DnsHandle, AsyncResolver, ConnectionProvider}; #[cfg(feature = "tokio")] use trust_dns_resolver::{TokioAsyncResolver, TokioConnection, TokioConnectionProvider}; -#[cfg(feature = "async-std")] -use async_std_resolver::{AsyncStdConnection, AsyncStdConnectionProvider}; pub use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; pub use trust_dns_resolver::error::{ResolveError, ResolveErrorKind}; @@ -112,7 +108,7 @@ pub type TokioDnsConfig = GenDnsConfig where C: DnsHandle, - P: ConnectionProvider + P: ConnectionProvider, { /// The underlying transport. inner: T, @@ -129,12 +125,14 @@ impl DnsConfig { } /// Creates a [`DnsConfig`] with a custom resolver configuration and options. - pub async fn custom(inner: T, cfg: ResolverConfig, opts: ResolverOpts) - -> Result, io::Error> - { + pub async fn custom( + inner: T, + cfg: ResolverConfig, + opts: ResolverOpts, + ) -> Result, io::Error> { Ok(DnsConfig { inner, - resolver: async_std_resolver::resolver(cfg, opts).await? + resolver: async_std_resolver::resolver(cfg, opts).await?, }) } } @@ -149,12 +147,14 @@ impl TokioDnsConfig { /// Creates a [`TokioDnsConfig`] with a custom resolver configuration /// and options. - pub fn custom(inner: T, cfg: ResolverConfig, opts: ResolverOpts) - -> Result, io::Error> - { + pub fn custom( + inner: T, + cfg: ResolverConfig, + opts: ResolverOpts, + ) -> Result, io::Error> { Ok(TokioDnsConfig { inner, - resolver: TokioAsyncResolver::tokio(cfg, opts)? + resolver: TokioAsyncResolver::tokio(cfg, opts)?, }) } } @@ -181,24 +181,29 @@ where type Output = T::Output; type Error = DnsErr; type Listener = stream::MapErr< - stream::MapOk) - -> ListenerEvent>, - fn(T::Error) -> Self::Error>; + stream::MapOk< + T::Listener, + fn( + ListenerEvent, + ) -> ListenerEvent, + >, + fn(T::Error) -> Self::Error, + >; type ListenerUpgrade = future::MapErr Self::Error>; type Dial = future::Either< future::MapErr Self::Error>, - BoxFuture<'static, Result> + BoxFuture<'static, Result>, >; fn listen_on(self, addr: Multiaddr) -> Result> { - let listener = self.inner.listen_on(addr).map_err(|err| err.map(DnsErr::Transport))?; + let listener = self + .inner + .listen_on(addr) + .map_err(|err| err.map(DnsErr::Transport))?; let listener = listener .map_ok::<_, fn(_) -> _>(|event| { event - .map(|upgr| { - upgr.map_err::<_, fn(_) -> _>(DnsErr::Transport) - }) + .map(|upgr| upgr.map_err::<_, fn(_) -> _>(DnsErr::Transport)) .map_err(DnsErr::Transport) }) .map_err::<_, fn(_) -> _>(DnsErr::Transport); @@ -225,24 +230,24 @@ where // address. while let Some(addr) = unresolved.pop() { if let Some((i, name)) = addr.iter().enumerate().find(|(_, p)| match p { - Protocol::Dns(_) | - Protocol::Dns4(_) | - Protocol::Dns6(_) | - Protocol::Dnsaddr(_) => true, - _ => false + Protocol::Dns(_) + | Protocol::Dns4(_) + | Protocol::Dns6(_) + | Protocol::Dnsaddr(_) => true, + _ => false, }) { if dns_lookups == MAX_DNS_LOOKUPS { log::debug!("Too many DNS lookups. Dropping unresolved {}.", addr); last_err = Some(DnsErr::TooManyLookups); // There may still be fully resolved addresses in `unresolved`, // so keep going until `unresolved` is empty. - continue + continue; } dns_lookups += 1; match resolve(&name, &resolver).await { Err(e) => { if unresolved.is_empty() { - return Err(e) + return Err(e); } // If there are still unresolved addresses, there is // a chance of success, but we track the last error. @@ -256,7 +261,8 @@ where Ok(Resolved::Many(ips)) => { for ip in ips { log::trace!("Resolved {} -> {}", name, ip); - let addr = addr.replace(i, |_| Some(ip)).expect("`i` is a valid index"); + let addr = + addr.replace(i, |_| Some(ip)).expect("`i` is a valid index"); unresolved.push(addr); } } @@ -269,10 +275,14 @@ where if n < MAX_TXT_RECORDS { n += 1; log::trace!("Resolved {} -> {}", name, a); - let addr = prefix.iter().chain(a.iter()).collect::(); + let addr = + prefix.iter().chain(a.iter()).collect::(); unresolved.push(addr); } else { - log::debug!("Too many TXT records. Dropping resolved {}.", a); + log::debug!( + "Too many TXT records. Dropping resolved {}.", + a + ); } } } @@ -291,9 +301,10 @@ where dial_attempts += 1; out.await.map_err(DnsErr::Transport) } - Err(TransportError::MultiaddrNotSupported(a)) => - Err(DnsErr::MultiaddrNotSupported(a)), - Err(TransportError::Other(err)) => Err(DnsErr::Transport(err)) + Err(TransportError::MultiaddrNotSupported(a)) => { + Err(DnsErr::MultiaddrNotSupported(a)) + } + Err(TransportError::Other(err)) => Err(DnsErr::Transport(err)), }; match result { @@ -301,11 +312,14 @@ where Err(err) => { log::debug!("Dial error: {:?}.", err); if unresolved.is_empty() { - return Err(err) + return Err(err); } if dial_attempts == MAX_DIAL_ATTEMPTS { - log::debug!("Aborting dialing after {} attempts.", MAX_DIAL_ATTEMPTS); - return Err(err) + log::debug!( + "Aborting dialing after {} attempts.", + MAX_DIAL_ATTEMPTS + ); + return Err(err); } last_err = Some(err); } @@ -317,10 +331,12 @@ where // attempt, return that error. Otherwise there were no valid DNS records // for the given address to begin with (i.e. DNS lookups succeeded but // produced no records relevant for the given `addr`). - Err(last_err.unwrap_or_else(|| - DnsErr::ResolveError( - ResolveErrorKind::Message("No matching records found.").into()))) - }.boxed().right_future()) + Err(last_err.unwrap_or_else(|| { + DnsErr::ResolveError(ResolveErrorKind::Message("No matching records found.").into()) + })) + } + .boxed() + .right_future()) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { @@ -348,7 +364,8 @@ pub enum DnsErr { } impl fmt::Display for DnsErr -where TErr: fmt::Display +where + TErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -361,7 +378,8 @@ where TErr: fmt::Display } impl error::Error for DnsErr -where TErr: error::Error + 'static +where + TErr: error::Error + 'static, { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { @@ -393,96 +411,111 @@ enum Resolved<'a> { /// [`Resolved::One`]. fn resolve<'a, E: 'a + Send, C, P>( proto: &Protocol<'a>, - resolver: &'a AsyncResolver, + resolver: &'a AsyncResolver, ) -> BoxFuture<'a, Result, DnsErr>> where C: DnsHandle, P: ConnectionProvider, { match proto { - Protocol::Dns(ref name) => { - resolver.lookup_ip(name.clone().into_owned()).map(move |res| match res { + Protocol::Dns(ref name) => resolver + .lookup_ip(name.clone().into_owned()) + .map(move |res| match res { Ok(ips) => { let mut ips = ips.into_iter(); - let one = ips.next() + let one = ips + .next() .expect("If there are no results, `Err(NoRecordsFound)` is expected."); if let Some(two) = ips.next() { Ok(Resolved::Many( - iter::once(one).chain(iter::once(two)) + iter::once(one) + .chain(iter::once(two)) .chain(ips) .map(Protocol::from) - .collect())) + .collect(), + )) } else { Ok(Resolved::One(Protocol::from(one))) } } - Err(e) => Err(DnsErr::ResolveError(e)) - }).boxed() - } - Protocol::Dns4(ref name) => { - resolver.ipv4_lookup(name.clone().into_owned()).map(move |res| match res { + Err(e) => Err(DnsErr::ResolveError(e)), + }) + .boxed(), + Protocol::Dns4(ref name) => resolver + .ipv4_lookup(name.clone().into_owned()) + .map(move |res| match res { Ok(ips) => { let mut ips = ips.into_iter(); - let one = ips.next() + let one = ips + .next() .expect("If there are no results, `Err(NoRecordsFound)` is expected."); if let Some(two) = ips.next() { Ok(Resolved::Many( - iter::once(one).chain(iter::once(two)) + iter::once(one) + .chain(iter::once(two)) .chain(ips) .map(IpAddr::from) .map(Protocol::from) - .collect())) + .collect(), + )) } else { Ok(Resolved::One(Protocol::from(IpAddr::from(one)))) } } - Err(e) => Err(DnsErr::ResolveError(e)) - }).boxed() - } - Protocol::Dns6(ref name) => { - resolver.ipv6_lookup(name.clone().into_owned()).map(move |res| match res { + Err(e) => Err(DnsErr::ResolveError(e)), + }) + .boxed(), + Protocol::Dns6(ref name) => resolver + .ipv6_lookup(name.clone().into_owned()) + .map(move |res| match res { Ok(ips) => { let mut ips = ips.into_iter(); - let one = ips.next() + let one = ips + .next() .expect("If there are no results, `Err(NoRecordsFound)` is expected."); if let Some(two) = ips.next() { Ok(Resolved::Many( - iter::once(one).chain(iter::once(two)) + iter::once(one) + .chain(iter::once(two)) .chain(ips) .map(IpAddr::from) .map(Protocol::from) - .collect())) + .collect(), + )) } else { Ok(Resolved::One(Protocol::from(IpAddr::from(one)))) } } - Err(e) => Err(DnsErr::ResolveError(e)) - }).boxed() - }, + Err(e) => Err(DnsErr::ResolveError(e)), + }) + .boxed(), Protocol::Dnsaddr(ref name) => { let name = [DNSADDR_PREFIX, name].concat(); - resolver.txt_lookup(name).map(move |res| match res { - Ok(txts) => { - let mut addrs = Vec::new(); - for txt in txts { - if let Some(chars) = txt.txt_data().first() { - match parse_dnsaddr_txt(chars) { - Err(e) => { - // Skip over seemingly invalid entries. - log::debug!("Invalid TXT record: {:?}", e); - } - Ok(a) => { - addrs.push(a); + resolver + .txt_lookup(name) + .map(move |res| match res { + Ok(txts) => { + let mut addrs = Vec::new(); + for txt in txts { + if let Some(chars) = txt.txt_data().first() { + match parse_dnsaddr_txt(chars) { + Err(e) => { + // Skip over seemingly invalid entries. + log::debug!("Invalid TXT record: {:?}", e); + } + Ok(a) => { + addrs.push(a); + } } } } + Ok(Resolved::Addrs(addrs)) } - Ok(Resolved::Addrs(addrs)) - } - Err(e) => Err(DnsErr::ResolveError(e)) - }).boxed() + Err(e) => Err(DnsErr::ResolveError(e)), + }) + .boxed() } - proto => future::ready(Ok(Resolved::One(proto.clone()))).boxed() + proto => future::ready(Ok(Resolved::One(proto.clone()))).boxed(), } } @@ -491,7 +524,7 @@ fn parse_dnsaddr_txt(txt: &[u8]) -> io::Result { let s = str::from_utf8(txt).map_err(invalid_data)?; match s.strip_prefix("dnsaddr=") { None => Err(invalid_data("Missing `dnsaddr=` prefix.")), - Some(a) => Ok(Multiaddr::try_from(a).map_err(invalid_data)?) + Some(a) => Ok(Multiaddr::try_from(a).map_err(invalid_data)?), } } @@ -504,11 +537,10 @@ mod tests { use super::*; use futures::{future::BoxFuture, stream::BoxStream}; use libp2p_core::{ - Transport, - PeerId, - multiaddr::{Protocol, Multiaddr}, + multiaddr::{Multiaddr, Protocol}, transport::ListenerEvent, transport::TransportError, + PeerId, Transport, }; #[test] @@ -521,19 +553,27 @@ mod tests { impl Transport for CustomTransport { type Output = (); type Error = std::io::Error; - type Listener = BoxStream<'static, Result, Self::Error>>; + type Listener = BoxStream< + 'static, + Result, Self::Error>, + >; type ListenerUpgrade = BoxFuture<'static, Result>; type Dial = BoxFuture<'static, Result>; - fn listen_on(self, _: Multiaddr) -> Result> { + fn listen_on( + self, + _: Multiaddr, + ) -> Result> { unreachable!() } fn dial(self, addr: Multiaddr) -> Result> { // Check that all DNS components have been resolved, i.e. replaced. assert!(!addr.iter().any(|p| match p { - Protocol::Dns(_) | Protocol::Dns4(_) | Protocol::Dns6(_) | Protocol::Dnsaddr(_) - => true, + Protocol::Dns(_) + | Protocol::Dns4(_) + | Protocol::Dns6(_) + | Protocol::Dnsaddr(_) => true, _ => false, })); Ok(Box::pin(future::ready(Ok(())))) @@ -598,13 +638,17 @@ mod tests { // an entry with a random `p2p` suffix. match transport .clone() - .dial(format!("/dnsaddr/bootstrap.libp2p.io/p2p/{}", PeerId::random()).parse().unwrap()) + .dial( + format!("/dnsaddr/bootstrap.libp2p.io/p2p/{}", PeerId::random()) + .parse() + .unwrap(), + ) .unwrap() .await { - Err(DnsErr::ResolveError(_)) => {}, + Err(DnsErr::ResolveError(_)) => {} Err(e) => panic!("Unexpected error: {:?}", e), - Ok(_) => panic!("Unexpected success.") + Ok(_) => panic!("Unexpected success."), } // Failure due to no records. @@ -615,7 +659,7 @@ mod tests { .await { Err(DnsErr::ResolveError(e)) => match e.kind() { - ResolveErrorKind::NoRecordsFound { .. } => {}, + ResolveErrorKind::NoRecordsFound { .. } => {} _ => panic!("Unexpected DNS error: {:?}", e), }, Err(e) => panic!("Unexpected error: {:?}", e), @@ -630,7 +674,7 @@ mod tests { let config = ResolverConfig::quad9(); let opts = ResolverOpts::default(); async_std_crate::task::block_on( - DnsConfig::custom(CustomTransport, config, opts).then(|dns| run(dns.unwrap())) + DnsConfig::custom(CustomTransport, config, opts).then(|dns| run(dns.unwrap())), ); } @@ -645,7 +689,9 @@ mod tests { .enable_time() .build() .unwrap(); - rt.block_on(run(TokioDnsConfig::custom(CustomTransport, config, opts).unwrap())); + rt.block_on(run( + TokioDnsConfig::custom(CustomTransport, config, opts).unwrap() + )); } } } diff --git a/transports/noise/build.rs b/transports/noise/build.rs index b13c29b5197..c9cf60412cd 100644 --- a/transports/noise/build.rs +++ b/transports/noise/build.rs @@ -19,5 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/io/handshake/payload.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/io/handshake/payload.proto"], &["src"]).unwrap(); } diff --git a/transports/noise/src/error.rs b/transports/noise/src/error.rs index 8b836d5ea78..4e1d240fe74 100644 --- a/transports/noise/src/error.rs +++ b/transports/noise/src/error.rs @@ -90,4 +90,3 @@ impl From for NoiseError { NoiseError::SigningError(e) } } - diff --git a/transports/noise/src/io.rs b/transports/noise/src/io.rs index c7bd110c773..37e35ecbeee 100644 --- a/transports/noise/src/io.rs +++ b/transports/noise/src/io.rs @@ -24,11 +24,16 @@ mod framed; pub mod handshake; use bytes::Bytes; -use framed::{MAX_FRAME_LEN, NoiseFramed}; -use futures::ready; +use framed::{NoiseFramed, MAX_FRAME_LEN}; use futures::prelude::*; +use futures::ready; use log::trace; -use std::{cmp::min, fmt, io, pin::Pin, task::{Context, Poll}}; +use std::{ + cmp::min, + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; /// A noise session to a remote. /// @@ -43,9 +48,7 @@ pub struct NoiseOutput { impl fmt::Debug for NoiseOutput { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("NoiseOutput") - .field("io", &self.io) - .finish() + f.debug_struct("NoiseOutput").field("io", &self.io).finish() } } @@ -62,13 +65,17 @@ impl NoiseOutput { } impl AsyncRead for NoiseOutput { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { loop { let len = self.recv_buffer.len(); let off = self.recv_offset; if len > 0 { let n = min(len - off, buf.len()); - buf[.. n].copy_from_slice(&self.recv_buffer[off .. off + n]); + buf[..n].copy_from_slice(&self.recv_buffer[off..off + n]); trace!("read: copied {}/{} bytes", off + n, len); self.recv_offset += n; if len == self.recv_offset { @@ -77,7 +84,7 @@ impl AsyncRead for NoiseOutput { // the buffer when polling for the next frame below. self.recv_buffer = Bytes::new(); } - return Poll::Ready(Ok(n)) + return Poll::Ready(Ok(n)); } match Pin::new(&mut self.io).poll_next(cx) { @@ -94,7 +101,11 @@ impl AsyncRead for NoiseOutput { } impl AsyncWrite for NoiseOutput { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { let this = Pin::into_inner(self); let mut io = Pin::new(&mut this.io); let frame_buf = &mut this.send_buffer; @@ -111,7 +122,7 @@ impl AsyncWrite for NoiseOutput { let n = min(MAX_FRAME_LEN, off.saturating_add(buf.len())); this.send_buffer.resize(n, 0u8); let n = min(MAX_FRAME_LEN - off, buf.len()); - this.send_buffer[off .. off + n].copy_from_slice(&buf[.. n]); + this.send_buffer[off..off + n].copy_from_slice(&buf[..n]); this.send_offset += n; trace!("write: buffered {} bytes", this.send_offset); @@ -134,7 +145,7 @@ impl AsyncWrite for NoiseOutput { io.as_mut().poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>{ + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush(cx))?; Pin::new(&mut self.io).poll_close(cx) } diff --git a/transports/noise/src/io/framed.rs b/transports/noise/src/io/framed.rs index 000300bdfcc..4ca228ebe54 100644 --- a/transports/noise/src/io/framed.rs +++ b/transports/noise/src/io/framed.rs @@ -21,13 +21,17 @@ //! This module provides a `Sink` and `Stream` for length-delimited //! Noise protocol messages in form of [`NoiseFramed`]. -use bytes::{Bytes, BytesMut}; -use crate::{NoiseError, Protocol, PublicKey}; use crate::io::NoiseOutput; -use futures::ready; +use crate::{NoiseError, Protocol, PublicKey}; +use bytes::{Bytes, BytesMut}; use futures::prelude::*; +use futures::ready; use log::{debug, trace}; -use std::{fmt, io, pin::Pin, task::{Context, Poll}}; +use std::{ + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; /// Max. size of a noise message. const MAX_NOISE_MSG_LEN: usize = 65535; @@ -88,14 +92,14 @@ impl NoiseFramed { /// present, cannot be parsed. pub fn into_transport(self) -> Result<(Option>, NoiseOutput), NoiseError> where - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { let dh_remote_pubkey = match self.session.get_remote_static() { None => None, Some(k) => match C::public_from_bytes(k) { Err(e) => return Err(e), - Ok(dh_pk) => Some(dh_pk) - } + Ok(dh_pk) => Some(dh_pk), + }, }; match self.session.into_transport_mode() { Err(e) => Err(e.into()), @@ -129,7 +133,7 @@ enum ReadState { /// The associated result signals if the EOF was unexpected or not. Eof(Result<(), ()>), /// A decryption error occurred (terminal state). - DecErr + DecErr, } /// The states for writing Noise protocol frames. @@ -138,19 +142,23 @@ enum WriteState { /// Ready to write another frame. Ready, /// Writing the frame length. - WriteLen { len: usize, buf: [u8; 2], off: usize }, + WriteLen { + len: usize, + buf: [u8; 2], + off: usize, + }, /// Writing the frame data. WriteData { len: usize, off: usize }, /// EOF has been reached unexpectedly (terminal state). Eof, /// An encryption error occurred (terminal state). - EncErr + EncErr, } impl WriteState { fn is_ready(&self) -> bool { if let WriteState::Ready = self { - return true + return true; } false } @@ -159,7 +167,7 @@ impl WriteState { impl futures::stream::Stream for NoiseFramed where T: AsyncRead + Unpin, - S: SessionState + Unpin + S: SessionState + Unpin, { type Item = io::Result; @@ -169,7 +177,10 @@ where trace!("read state: {:?}", this.read_state); match this.read_state { ReadState::Ready => { - this.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; + this.read_state = ReadState::ReadLen { + buf: [0, 0], + off: 0, + }; } ReadState::ReadLen { mut buf, mut off } => { let n = match read_frame_len(&mut this.io, cx, &mut buf, &mut off) { @@ -177,11 +188,9 @@ where Poll::Ready(Ok(None)) => { trace!("read: eof"); this.read_state = ReadState::Eof(Ok(())); - return Poll::Ready(None) - } - Poll::Ready(Err(e)) => { - return Poll::Ready(Some(Err(e))) + return Poll::Ready(None); } + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), Poll::Pending => { this.read_state = ReadState::ReadLen { buf, off }; return Poll::Pending; @@ -191,14 +200,18 @@ where if n == 0 { trace!("read: empty frame"); this.read_state = ReadState::Ready; - continue + continue; } this.read_buffer.resize(usize::from(n), 0u8); - this.read_state = ReadState::ReadData { len: usize::from(n), off: 0 } + this.read_state = ReadState::ReadData { + len: usize::from(n), + off: 0, + } } ReadState::ReadData { len, ref mut off } => { let n = { - let f = Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off .. len]); + let f = + Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off..len]); match ready!(f) { Ok(n) => n, Err(e) => return Poll::Ready(Some(Err(e))), @@ -208,13 +221,16 @@ where if n == 0 { trace!("read: eof"); this.read_state = ReadState::Eof(Err(())); - return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))) + return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))); } *off += n; if len == *off { trace!("read: decrypting {} bytes", len); this.decrypt_buffer.resize(len, 0); - if let Ok(n) = this.session.read_message(&this.read_buffer, &mut this.decrypt_buffer) { + if let Ok(n) = this + .session + .read_message(&this.read_buffer, &mut this.decrypt_buffer) + { this.decrypt_buffer.truncate(n); trace!("read: payload len = {} bytes", n); this.read_state = ReadState::Ready; @@ -223,23 +239,25 @@ where // read, the `BytesMut` will reuse the same buffer // for the next frame. let view = this.decrypt_buffer.split().freeze(); - return Poll::Ready(Some(Ok(view))) + return Poll::Ready(Some(Ok(view))); } else { debug!("read: decryption error"); this.read_state = ReadState::DecErr; - return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))) + return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))); } } } ReadState::Eof(Ok(())) => { trace!("read: eof"); - return Poll::Ready(None) + return Poll::Ready(None); } ReadState::Eof(Err(())) => { trace!("read: eof (unexpected)"); - return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))) + return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))); + } + ReadState::DecErr => { + return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))) } - ReadState::DecErr => return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))) } } } @@ -248,7 +266,7 @@ where impl futures::sink::Sink<&Vec> for NoiseFramed where T: AsyncWrite + Unpin, - S: SessionState + Unpin + S: SessionState + Unpin, { type Error = io::Error; @@ -267,21 +285,20 @@ where Poll::Ready(Ok(false)) => { trace!("write: eof"); this.write_state = WriteState::Eof; - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) - } - Poll::Ready(Err(e)) => { - return Poll::Ready(Err(e)) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => { this.write_state = WriteState::WriteLen { len, buf, off }; - return Poll::Pending + return Poll::Pending; } } this.write_state = WriteState::WriteData { len, off: 0 } } WriteState::WriteData { len, ref mut off } => { let n = { - let f = Pin::new(&mut this.io).poll_write(cx, &this.write_buffer[*off .. len]); + let f = + Pin::new(&mut this.io).poll_write(cx, &this.write_buffer[*off..len]); match ready!(f) { Ok(n) => n, Err(e) => return Poll::Ready(Err(e)), @@ -290,7 +307,7 @@ where if n == 0 { trace!("write: eof"); this.write_state = WriteState::Eof; - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } *off += n; trace!("write: {}/{} bytes written", *off, len); @@ -301,9 +318,9 @@ where } WriteState::Eof => { trace!("write: eof"); - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } - WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())) + WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())), } } } @@ -313,15 +330,19 @@ where let mut this = Pin::into_inner(self); assert!(this.write_state.is_ready()); - this.write_buffer.resize(frame.len() + EXTRA_ENCRYPT_SPACE, 0u8); - match this.session.write_message(frame, &mut this.write_buffer[..]) { + this.write_buffer + .resize(frame.len() + EXTRA_ENCRYPT_SPACE, 0u8); + match this + .session + .write_message(frame, &mut this.write_buffer[..]) + { Ok(n) => { trace!("write: cipher text len = {} bytes", n); this.write_buffer.truncate(n); this.write_state = WriteState::WriteLen { len: n, buf: u16::to_be_bytes(n as u16), - off: 0 + off: 0, }; Ok(()) } @@ -386,7 +407,7 @@ fn read_frame_len( off: &mut usize, ) -> Poll>> { loop { - match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off ..])) { + match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off..])) { Ok(n) => { if n == 0 { return Poll::Ready(Ok(None)); @@ -395,10 +416,10 @@ fn read_frame_len( if *off == 2 { return Poll::Ready(Ok(Some(u16::from_be_bytes(*buf)))); } - }, + } Err(e) => { return Poll::Ready(Err(e)); - }, + } } } } @@ -419,14 +440,14 @@ fn write_frame_len( off: &mut usize, ) -> Poll> { loop { - match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off ..])) { + match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off..])) { Ok(n) => { if n == 0 { - return Poll::Ready(Ok(false)) + return Poll::Ready(Ok(false)); } *off += n; if *off == 2 { - return Poll::Ready(Ok(true)) + return Poll::Ready(Ok(true)); } } Err(e) => { @@ -435,4 +456,3 @@ fn write_frame_len( } } } - diff --git a/transports/noise/src/io/handshake.rs b/transports/noise/src/io/handshake.rs index 21faa84d100..fa97798fb23 100644 --- a/transports/noise/src/io/handshake.rs +++ b/transports/noise/src/io/handshake.rs @@ -24,14 +24,14 @@ mod payload_proto { include!(concat!(env!("OUT_DIR"), "/payload.proto.rs")); } -use bytes::Bytes; -use crate::LegacyConfig; use crate::error::NoiseError; -use crate::protocol::{Protocol, PublicKey, KeypairIdentity}; -use crate::io::{NoiseOutput, framed::NoiseFramed}; -use libp2p_core::identity; +use crate::io::{framed::NoiseFramed, NoiseOutput}; +use crate::protocol::{KeypairIdentity, Protocol, PublicKey}; +use crate::LegacyConfig; +use bytes::Bytes; use futures::prelude::*; use futures::task; +use libp2p_core::identity; use prost::Message; use std::{io, pin::Pin, task::Context}; @@ -59,7 +59,7 @@ pub enum RemoteIdentity { /// > **Note**: To rule out active attacks like a MITM, trust in the public key must /// > still be established, e.g. by comparing the key against an expected or /// > otherwise known public key. - IdentityKey(identity::PublicKey) + IdentityKey(identity::PublicKey), } /// The options for identity exchange in an authenticated handshake. @@ -87,14 +87,12 @@ pub enum IdentityExchange { /// /// The remote identity is known, thus identities must be mutually known /// in order for the handshake to succeed. - None { remote: identity::PublicKey } + None { remote: identity::PublicKey }, } /// A future performing a Noise handshake pattern. pub struct Handshake( - Pin, NoiseOutput), NoiseError>, - > + Send>> + Pin, NoiseOutput), NoiseError>> + Send>>, ); impl Future for Handshake { @@ -131,7 +129,7 @@ pub fn rt1_initiator( ) -> Handshake where T: AsyncWrite + AsyncRead + Send + Unpin + 'static, - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { Handshake(Box::pin(async move { let mut state = State::new(io, session, identity, identity_x, legacy)?; @@ -166,7 +164,7 @@ pub fn rt1_responder( ) -> Handshake where T: AsyncWrite + AsyncRead + Send + Unpin + 'static, - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { Handshake(Box::pin(async move { let mut state = State::new(io, session, identity, identity_x, legacy)?; @@ -203,7 +201,7 @@ pub fn rt15_initiator( ) -> Handshake where T: AsyncWrite + AsyncRead + Unpin + Send + 'static, - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { Handshake(Box::pin(async move { let mut state = State::new(io, session, identity, identity_x, legacy)?; @@ -241,7 +239,7 @@ pub fn rt15_responder( ) -> Handshake where T: AsyncWrite + AsyncRead + Unpin + Send + 'static, - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { Handshake(Box::pin(async move { let mut state = State::new(io, session, identity, identity_x, legacy)?; @@ -289,28 +287,25 @@ impl State { IdentityExchange::Mutual => (None, true), IdentityExchange::Send { remote } => (Some(remote), true), IdentityExchange::Receive => (None, false), - IdentityExchange::None { remote } => (Some(remote), false) + IdentityExchange::None { remote } => (Some(remote), false), }; - session.map(|s| - State { - identity, - io: NoiseFramed::new(io, s), - dh_remote_pubkey_sig: None, - id_remote_pubkey, - send_identity, - legacy, - } - ) + session.map(|s| State { + identity, + io: NoiseFramed::new(io, s), + dh_remote_pubkey_sig: None, + id_remote_pubkey, + send_identity, + legacy, + }) } } -impl State -{ +impl State { /// Finish a handshake, yielding the established remote identity and the /// [`NoiseOutput`] for communicating on the encrypted channel. fn finish(self) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> where - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { let (pubkey, io) = self.io.into_transport()?; let remote = match (self.id_remote_pubkey, pubkey) { @@ -320,7 +315,7 @@ impl State if C::verify(&id_pk, &dh_pk, &self.dh_remote_pubkey_sig) { RemoteIdentity::IdentityKey(id_pk) } else { - return Err(NoiseError::InvalidKey) + return Err(NoiseError::InvalidKey); } } }; @@ -334,7 +329,7 @@ impl State /// A future for receiving a Noise handshake message. async fn recv(state: &mut State) -> Result where - T: AsyncRead + Unpin + T: AsyncRead + Unpin, { match state.io.next().await { None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "eof").into()), @@ -346,13 +341,13 @@ where /// A future for receiving a Noise handshake message with an empty payload. async fn recv_empty(state: &mut State) -> Result<(), NoiseError> where - T: AsyncRead + Unpin + T: AsyncRead + Unpin, { let msg = recv(state).await?; if !msg.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Unexpected handshake payload.").into()) + return Err( + io::Error::new(io::ErrorKind::InvalidData, "Unexpected handshake payload.").into(), + ); } Ok(()) } @@ -360,7 +355,7 @@ where /// A future for sending a Noise handshake message with an empty payload. async fn send_empty(state: &mut State) -> Result<(), NoiseError> where - T: AsyncWrite + Unpin + T: AsyncWrite + Unpin, { state.io.send(&Vec::new()).await?; Ok(()) @@ -390,12 +385,12 @@ where pb_result = pb_result.or_else(|e| { if msg.len() > 2 { let mut buf = [0, 0]; - buf.copy_from_slice(&msg[.. 2]); + buf.copy_from_slice(&msg[..2]); // If there is a second length it must be 2 bytes shorter than the // frame length, because each length is encoded as a `u16`. if usize::from(u16::from_be_bytes(buf)) + 2 == msg.len() { log::debug!("Attempting fallback legacy protobuf decoding."); - payload_proto::NoiseHandshakePayload::decode(&msg[2 ..]) + payload_proto::NoiseHandshakePayload::decode(&msg[2..]) } else { Err(e) } @@ -411,7 +406,7 @@ where .map_err(|_| NoiseError::InvalidKey)?; if let Some(ref k) = state.id_remote_pubkey { if k != &pk { - return Err(NoiseError::InvalidKey) + return Err(NoiseError::InvalidKey); } } state.id_remote_pubkey = Some(pk); @@ -439,16 +434,16 @@ where pb.identity_sig = sig.clone() } - let mut msg = - if state.legacy.send_legacy_handshake { - let mut msg = Vec::with_capacity(2 + pb.encoded_len()); - msg.extend_from_slice(&(pb.encoded_len() as u16).to_be_bytes()); - msg - } else { - Vec::with_capacity(pb.encoded_len()) - }; + let mut msg = if state.legacy.send_legacy_handshake { + let mut msg = Vec::with_capacity(2 + pb.encoded_len()); + msg.extend_from_slice(&(pb.encoded_len() as u16).to_be_bytes()); + msg + } else { + Vec::with_capacity(pb.encoded_len()) + }; - pb.encode(&mut msg).expect("Vec provides capacity as needed"); + pb.encode(&mut msg) + .expect("Vec provides capacity as needed"); state.io.send(&msg).await?; Ok(()) diff --git a/transports/noise/src/lib.rs b/transports/noise/src/lib.rs index 3cb685454bb..d6141483a40 100644 --- a/transports/noise/src/lib.rs +++ b/transports/noise/src/lib.rs @@ -59,15 +59,15 @@ mod io; mod protocol; pub use error::NoiseError; -pub use io::NoiseOutput; pub use io::handshake; -pub use io::handshake::{Handshake, RemoteIdentity, IdentityExchange}; -pub use protocol::{Keypair, AuthenticKeypair, KeypairIdentity, PublicKey, SecretKey}; -pub use protocol::{Protocol, ProtocolParams, IX, IK, XX}; +pub use io::handshake::{Handshake, IdentityExchange, RemoteIdentity}; +pub use io::NoiseOutput; pub use protocol::{x25519::X25519, x25519_spec::X25519Spec}; +pub use protocol::{AuthenticKeypair, Keypair, KeypairIdentity, PublicKey, SecretKey}; +pub use protocol::{Protocol, ProtocolParams, IK, IX, XX}; use futures::prelude::*; -use libp2p_core::{identity, PeerId, UpgradeInfo, InboundUpgrade, OutboundUpgrade}; +use libp2p_core::{identity, InboundUpgrade, OutboundUpgrade, PeerId, UpgradeInfo}; use std::pin::Pin; use zeroize::Zeroize; @@ -78,7 +78,7 @@ pub struct NoiseConfig { params: ProtocolParams, legacy: LegacyConfig, remote: R, - _marker: std::marker::PhantomData

+ _marker: std::marker::PhantomData

, } impl NoiseConfig { @@ -97,7 +97,7 @@ impl NoiseConfig { impl NoiseConfig where - C: Protocol + Zeroize + C: Protocol + Zeroize, { /// Create a new `NoiseConfig` for the `IX` handshake pattern. pub fn ix(dh_keys: AuthenticKeypair) -> Self { @@ -106,14 +106,14 @@ where params: C::params_ix(), legacy: LegacyConfig::default(), remote: (), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } impl NoiseConfig where - C: Protocol + Zeroize + C: Protocol + Zeroize, { /// Create a new `NoiseConfig` for the `XX` handshake pattern. pub fn xx(dh_keys: AuthenticKeypair) -> Self { @@ -122,14 +122,14 @@ where params: C::params_xx(), legacy: LegacyConfig::default(), remote: (), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } impl NoiseConfig where - C: Protocol + Zeroize + C: Protocol + Zeroize, { /// Create a new `NoiseConfig` for the `IK` handshake pattern (recipient side). /// @@ -141,14 +141,14 @@ where params: C::params_ik(), legacy: LegacyConfig::default(), remote: (), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } impl NoiseConfig, identity::PublicKey)> where - C: Protocol + Zeroize + C: Protocol + Zeroize, { /// Create a new `NoiseConfig` for the `IK` handshake pattern (initiator side). /// @@ -157,14 +157,14 @@ where pub fn ik_dialer( dh_keys: AuthenticKeypair, remote_id: identity::PublicKey, - remote_dh: PublicKey + remote_dh: PublicKey, ) -> Self { NoiseConfig { dh_keys, params: C::params_ik(), legacy: LegacyConfig::default(), remote: (remote_dh, remote_id), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } @@ -182,14 +182,19 @@ where type Future = Handshake; fn upgrade_inbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_responder() .map_err(NoiseError::from); - handshake::rt1_responder(socket, session, + handshake::rt1_responder( + socket, + session, self.dh_keys.into_identity(), IdentityExchange::Mutual, - self.legacy) + self.legacy, + ) } } @@ -204,14 +209,19 @@ where type Future = Handshake; fn upgrade_outbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_initiator() .map_err(NoiseError::from); - handshake::rt1_initiator(socket, session, - self.dh_keys.into_identity(), - IdentityExchange::Mutual, - self.legacy) + handshake::rt1_initiator( + socket, + session, + self.dh_keys.into_identity(), + IdentityExchange::Mutual, + self.legacy, + ) } } @@ -228,14 +238,19 @@ where type Future = Handshake; fn upgrade_inbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_responder() .map_err(NoiseError::from); - handshake::rt15_responder(socket, session, + handshake::rt15_responder( + socket, + session, self.dh_keys.into_identity(), IdentityExchange::Mutual, - self.legacy) + self.legacy, + ) } } @@ -250,14 +265,19 @@ where type Future = Handshake; fn upgrade_outbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_initiator() .map_err(NoiseError::from); - handshake::rt15_initiator(socket, session, + handshake::rt15_initiator( + socket, + session, self.dh_keys.into_identity(), IdentityExchange::Mutual, - self.legacy) + self.legacy, + ) } } @@ -274,14 +294,19 @@ where type Future = Handshake; fn upgrade_inbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_responder() .map_err(NoiseError::from); - handshake::rt1_responder(socket, session, + handshake::rt1_responder( + socket, + session, self.dh_keys.into_identity(), IdentityExchange::Receive, - self.legacy) + self.legacy, + ) } } @@ -296,15 +321,22 @@ where type Future = Handshake; fn upgrade_outbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .remote_public_key(self.remote.0.as_ref()) .build_initiator() .map_err(NoiseError::from); - handshake::rt1_initiator(socket, session, + handshake::rt1_initiator( + socket, + session, self.dh_keys.into_identity(), - IdentityExchange::Send { remote: self.remote.1 }, - self.legacy) + IdentityExchange::Send { + remote: self.remote.1, + }, + self.legacy, + ) } } @@ -322,12 +354,12 @@ where /// transport for use with a [`Network`](libp2p_core::Network). #[derive(Clone)] pub struct NoiseAuthenticated { - config: NoiseConfig + config: NoiseConfig, } impl UpgradeInfo for NoiseAuthenticated where - NoiseConfig: UpgradeInfo + NoiseConfig: UpgradeInfo, { type Info = as UpgradeInfo>::Info; type InfoIter = as UpgradeInfo>::InfoIter; @@ -339,10 +371,9 @@ where impl InboundUpgrade for NoiseAuthenticated where - NoiseConfig: UpgradeInfo + InboundUpgrade, NoiseOutput), - Error = NoiseError - > + 'static, + NoiseConfig: UpgradeInfo + + InboundUpgrade, NoiseOutput), Error = NoiseError> + + 'static, as InboundUpgrade>::Future: Send, T: AsyncRead + AsyncWrite + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, @@ -352,20 +383,22 @@ where type Future = Pin> + Send>>; fn upgrade_inbound(self, socket: T, info: Self::Info) -> Self::Future { - Box::pin(self.config.upgrade_inbound(socket, info) - .and_then(|(remote, io)| match remote { - RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)), - _ => future::err(NoiseError::AuthenticationFailed) - })) + Box::pin( + self.config + .upgrade_inbound(socket, info) + .and_then(|(remote, io)| match remote { + RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)), + _ => future::err(NoiseError::AuthenticationFailed), + }), + ) } } impl OutboundUpgrade for NoiseAuthenticated where - NoiseConfig: UpgradeInfo + OutboundUpgrade, NoiseOutput), - Error = NoiseError - > + 'static, + NoiseConfig: UpgradeInfo + + OutboundUpgrade, NoiseOutput), Error = NoiseError> + + 'static, as OutboundUpgrade>::Future: Send, T: AsyncRead + AsyncWrite + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, @@ -375,11 +408,14 @@ where type Future = Pin> + Send>>; fn upgrade_outbound(self, socket: T, info: Self::Info) -> Self::Future { - Box::pin(self.config.upgrade_outbound(socket, info) - .and_then(|(remote, io)| match remote { - RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)), - _ => future::err(NoiseError::AuthenticationFailed) - })) + Box::pin( + self.config + .upgrade_outbound(socket, info) + .and_then(|(remote, io)| match remote { + RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)), + _ => future::err(NoiseError::AuthenticationFailed), + }), + ) } } diff --git a/transports/noise/src/protocol.rs b/transports/noise/src/protocol.rs index 7c61274acb5..aa2acb150a9 100644 --- a/transports/noise/src/protocol.rs +++ b/transports/noise/src/protocol.rs @@ -92,16 +92,17 @@ pub trait Protocol { #[allow(deprecated)] fn verify(id_pk: &identity::PublicKey, dh_pk: &PublicKey, sig: &Option>) -> bool where - C: AsRef<[u8]> + C: AsRef<[u8]>, { Self::linked(id_pk, dh_pk) - || - sig.as_ref().map_or(false, |s| id_pk.verify(dh_pk.as_ref(), s)) + || sig + .as_ref() + .map_or(false, |s| id_pk.verify(dh_pk.as_ref(), s)) } fn sign(id_keys: &identity::Keypair, dh_pk: &PublicKey) -> Result, NoiseError> where - C: AsRef<[u8]> + C: AsRef<[u8]>, { Ok(id_keys.sign(dh_pk.as_ref())?) } @@ -118,7 +119,7 @@ pub struct Keypair { #[derive(Clone)] pub struct AuthenticKeypair { keypair: Keypair, - identity: KeypairIdentity + identity: KeypairIdentity, } impl AuthenticKeypair { @@ -143,7 +144,7 @@ pub struct KeypairIdentity { /// The public identity key. pub public: identity::PublicKey, /// The signature over the public DH key. - pub signature: Option> + pub signature: Option>, } impl Keypair { @@ -159,19 +160,25 @@ impl Keypair { /// Turn this DH keypair into a [`AuthenticKeypair`], i.e. a DH keypair that /// is authentic w.r.t. the given identity keypair, by signing the DH public key. - pub fn into_authentic(self, id_keys: &identity::Keypair) -> Result, NoiseError> + pub fn into_authentic( + self, + id_keys: &identity::Keypair, + ) -> Result, NoiseError> where T: AsRef<[u8]>, - T: Protocol + T: Protocol, { let sig = T::sign(id_keys, &self.public)?; let identity = KeypairIdentity { public: id_keys.public(), - signature: Some(sig) + signature: Some(sig), }; - Ok(AuthenticKeypair { keypair: self, identity }) + Ok(AuthenticKeypair { + keypair: self, + identity, + }) } } @@ -228,7 +235,10 @@ impl snow::resolvers::CryptoResolver for Resolver { } } - fn resolve_hash(&self, choice: &snow::params::HashChoice) -> Option> { + fn resolve_hash( + &self, + choice: &snow::params::HashChoice, + ) -> Option> { #[cfg(target_arch = "wasm32")] { snow::resolvers::DefaultResolver.resolve_hash(choice) @@ -239,7 +249,10 @@ impl snow::resolvers::CryptoResolver for Resolver { } } - fn resolve_cipher(&self, choice: &snow::params::CipherChoice) -> Option> { + fn resolve_cipher( + &self, + choice: &snow::params::CipherChoice, + ) -> Option> { #[cfg(target_arch = "wasm32")] { snow::resolvers::DefaultResolver.resolve_cipher(choice) diff --git a/transports/noise/src/protocol/x25519.rs b/transports/noise/src/protocol/x25519.rs index c4e79bc33ae..c0d3936ee36 100644 --- a/transports/noise/src/protocol/x25519.rs +++ b/transports/noise/src/protocol/x25519.rs @@ -29,8 +29,8 @@ use lazy_static::lazy_static; use libp2p_core::UpgradeInfo; use libp2p_core::{identity, identity::ed25519}; use rand::Rng; -use sha2::{Sha512, Digest}; -use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; +use sha2::{Digest, Sha512}; +use x25519_dalek::{x25519, X25519_BASEPOINT_BYTES}; use zeroize::Zeroize; use super::*; @@ -40,12 +40,10 @@ lazy_static! { .parse() .map(ProtocolParams) .expect("Invalid protocol name"); - static ref PARAMS_IX: ProtocolParams = "Noise_IX_25519_ChaChaPoly_SHA256" .parse() .map(ProtocolParams) .expect("Invalid protocol name"); - static ref PARAMS_XX: ProtocolParams = "Noise_XX_25519_ChaChaPoly_SHA256" .parse() .map(ProtocolParams) @@ -115,7 +113,7 @@ impl Protocol for X25519 { fn public_from_bytes(bytes: &[u8]) -> Result, NoiseError> { if bytes.len() != 32 { - return Err(NoiseError::InvalidKey) + return Err(NoiseError::InvalidKey); } let mut pk = [0u8; 32]; pk.copy_from_slice(bytes); @@ -137,7 +135,7 @@ impl Keypair { pub(super) fn default() -> Self { Keypair { secret: SecretKey(X25519([0u8; 32])), - public: PublicKey(X25519([0u8; 32])) + public: PublicKey(X25519([0u8; 32])), } } @@ -170,14 +168,14 @@ impl Keypair { let kp = Keypair::from(SecretKey::from_ed25519(&p.secret())); let id = KeypairIdentity { public: id_keys.public(), - signature: None + signature: None, }; Some(AuthenticKeypair { keypair: kp, - identity: id + identity: id, }) } - _ => None + _ => None, } } } @@ -193,10 +191,13 @@ impl From> for Keypair { impl PublicKey { /// Construct a curve25519 public key from an Ed25519 public key. pub fn from_ed25519(pk: &ed25519::PublicKey) -> Self { - PublicKey(X25519(CompressedEdwardsY(pk.encode()) - .decompress() - .expect("An Ed25519 public key is a valid point by construction.") - .to_montgomery().0)) + PublicKey(X25519( + CompressedEdwardsY(pk.encode()) + .decompress() + .expect("An Ed25519 public key is a valid point by construction.") + .to_montgomery() + .0, + )) } } @@ -227,11 +228,21 @@ impl SecretKey { #[doc(hidden)] impl snow::types::Dh for Keypair { - fn name(&self) -> &'static str { "25519" } - fn pub_len(&self) -> usize { 32 } - fn priv_len(&self) -> usize { 32 } - fn pubkey(&self) -> &[u8] { self.public.as_ref() } - fn privkey(&self) -> &[u8] { self.secret.as_ref() } + fn name(&self) -> &'static str { + "25519" + } + fn pub_len(&self) -> usize { + 32 + } + fn priv_len(&self) -> usize { + 32 + } + fn pubkey(&self) -> &[u8] { + self.public.as_ref() + } + fn privkey(&self) -> &[u8] { + self.secret.as_ref() + } fn set(&mut self, sk: &[u8]) { let mut secret = [0u8; 32]; @@ -251,20 +262,20 @@ impl snow::types::Dh for Keypair { fn dh(&self, pk: &[u8], shared_secret: &mut [u8]) -> Result<(), ()> { let mut p = [0; 32]; - p.copy_from_slice(&pk[.. 32]); + p.copy_from_slice(&pk[..32]); let ss = x25519((self.secret.0).0, p); - shared_secret[.. 32].copy_from_slice(&ss[..]); + shared_secret[..32].copy_from_slice(&ss[..]); Ok(()) } } #[cfg(test)] mod tests { + use super::*; use libp2p_core::identity::ed25519; use quickcheck::*; use sodiumoxide::crypto::sign; use std::os::raw::c_int; - use super::*; use x25519_dalek::StaticSecret; // ed25519 to x25519 keypair conversion must yield the same results as @@ -276,7 +287,8 @@ mod tests { let x25519 = Keypair::from(SecretKey::from_ed25519(&ed25519.secret())); let sodium_sec = ed25519_sk_to_curve25519(&sign::SecretKey(ed25519.encode())); - let sodium_pub = ed25519_pk_to_curve25519(&sign::PublicKey(ed25519.public().encode().clone())); + let sodium_pub = + ed25519_pk_to_curve25519(&sign::PublicKey(ed25519.public().encode().clone())); let our_pub = x25519.public.0; // libsodium does the [clamping] of the scalar upon key construction, @@ -288,8 +300,7 @@ mod tests { // [clamping]: http://www.lix.polytechnique.fr/~smith/ECC/#scalar-clamping let our_sec = StaticSecret::from((x25519.secret.0).0).to_bytes(); - sodium_sec.as_ref() == Some(&our_sec) && - sodium_pub.as_ref() == Some(&our_pub.0) + sodium_sec.as_ref() == Some(&our_sec) && sodium_pub.as_ref() == Some(&our_pub.0) } quickcheck(prop as fn() -> _); @@ -340,4 +351,3 @@ mod tests { } } } - diff --git a/transports/noise/src/protocol/x25519_spec.rs b/transports/noise/src/protocol/x25519_spec.rs index 16e3ffeafee..2f2c24237a6 100644 --- a/transports/noise/src/protocol/x25519_spec.rs +++ b/transports/noise/src/protocol/x25519_spec.rs @@ -23,13 +23,13 @@ //! [libp2p-noise-spec]: https://github.com/libp2p/specs/tree/master/noise use crate::{NoiseConfig, NoiseError, Protocol, ProtocolParams}; -use libp2p_core::UpgradeInfo; use libp2p_core::identity; +use libp2p_core::UpgradeInfo; use rand::Rng; -use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; +use x25519_dalek::{x25519, X25519_BASEPOINT_BYTES}; use zeroize::Zeroize; -use super::{*, x25519::X25519}; +use super::{x25519::X25519, *}; /// Prefix of static key signatures for domain separation. const STATIC_KEY_DOMAIN: &str = "noise-libp2p-static-key:"; @@ -117,32 +117,48 @@ impl Protocol for X25519Spec { fn public_from_bytes(bytes: &[u8]) -> Result, NoiseError> { if bytes.len() != 32 { - return Err(NoiseError::InvalidKey) + return Err(NoiseError::InvalidKey); } let mut pk = [0u8; 32]; pk.copy_from_slice(bytes); Ok(PublicKey(X25519Spec(pk))) } - fn verify(id_pk: &identity::PublicKey, dh_pk: &PublicKey, sig: &Option>) -> bool - { + fn verify( + id_pk: &identity::PublicKey, + dh_pk: &PublicKey, + sig: &Option>, + ) -> bool { sig.as_ref().map_or(false, |s| { id_pk.verify(&[STATIC_KEY_DOMAIN.as_bytes(), dh_pk.as_ref()].concat(), s) }) } - fn sign(id_keys: &identity::Keypair, dh_pk: &PublicKey) -> Result, NoiseError> { + fn sign( + id_keys: &identity::Keypair, + dh_pk: &PublicKey, + ) -> Result, NoiseError> { Ok(id_keys.sign(&[STATIC_KEY_DOMAIN.as_bytes(), dh_pk.as_ref()].concat())?) } } #[doc(hidden)] impl snow::types::Dh for Keypair { - fn name(&self) -> &'static str { "25519" } - fn pub_len(&self) -> usize { 32 } - fn priv_len(&self) -> usize { 32 } - fn pubkey(&self) -> &[u8] { self.public.as_ref() } - fn privkey(&self) -> &[u8] { self.secret.as_ref() } + fn name(&self) -> &'static str { + "25519" + } + fn pub_len(&self) -> usize { + 32 + } + fn priv_len(&self) -> usize { + 32 + } + fn pubkey(&self) -> &[u8] { + self.public.as_ref() + } + fn privkey(&self) -> &[u8] { + self.secret.as_ref() + } fn set(&mut self, sk: &[u8]) { let mut secret = [0u8; 32]; @@ -162,9 +178,9 @@ impl snow::types::Dh for Keypair { fn dh(&self, pk: &[u8], shared_secret: &mut [u8]) -> Result<(), ()> { let mut p = [0; 32]; - p.copy_from_slice(&pk[.. 32]); + p.copy_from_slice(&pk[..32]); let ss = x25519((self.secret.0).0, p); - shared_secret[.. 32].copy_from_slice(&ss[..]); + shared_secret[..32].copy_from_slice(&ss[..]); Ok(()) } } diff --git a/transports/noise/tests/smoke.rs b/transports/noise/tests/smoke.rs index 4a4c81b5eb8..e1e1e1c0c04 100644 --- a/transports/noise/tests/smoke.rs +++ b/transports/noise/tests/smoke.rs @@ -19,11 +19,16 @@ // DEALINGS IN THE SOFTWARE. use async_io::Async; -use futures::{future::{self, Either}, prelude::*}; +use futures::{ + future::{self, Either}, + prelude::*, +}; use libp2p_core::identity; -use libp2p_core::upgrade::{self, Negotiated, apply_inbound, apply_outbound}; -use libp2p_core::transport::{Transport, ListenerEvent}; -use libp2p_noise::{Keypair, X25519, X25519Spec, NoiseConfig, RemoteIdentity, NoiseError, NoiseOutput}; +use libp2p_core::transport::{ListenerEvent, Transport}; +use libp2p_core::upgrade::{self, apply_inbound, apply_outbound, Negotiated}; +use libp2p_noise::{ + Keypair, NoiseConfig, NoiseError, NoiseOutput, RemoteIdentity, X25519Spec, X25519, +}; use libp2p_tcp::TcpConfig; use log::info; use quickcheck::QuickCheck; @@ -36,7 +41,9 @@ fn core_upgrade_compat() { let id_keys = identity::Keypair::generate_ed25519(); let dh_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); let noise = NoiseConfig::xx(dh_keys).into_authenticated(); - let _ = TcpConfig::new().upgrade(upgrade::Version::V1).authenticate(noise); + let _ = TcpConfig::new() + .upgrade(upgrade::Version::V1) + .authenticate(noise); } #[test] @@ -50,24 +57,40 @@ fn xx_spec() { let server_id_public = server_id.public(); let client_id_public = client_id.public(); - let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); + let server_dh = Keypair::::new() + .into_authentic(&server_id) + .unwrap(); let server_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::xx(server_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::xx(server_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &client_id_public)); - let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); + let client_dh = Keypair::::new() + .into_authentic(&client_id) + .unwrap(); let client_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::xx(client_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::xx(client_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &server_id_public)); run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec) -> bool) } #[test] @@ -84,21 +107,33 @@ fn xx() { let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); let server_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::xx(server_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::xx(server_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &client_id_public)); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); let client_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::xx(client_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::xx(client_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &server_id_public)); run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec) -> bool) } #[test] @@ -115,21 +150,33 @@ fn ix() { let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); let server_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::ix(server_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::ix(server_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &client_id_public)); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); let client_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::ix(client_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::ix(client_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &server_id_public)); run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec) -> bool) } #[test] @@ -150,8 +197,11 @@ fn ik_xx() { if endpoint.is_listener() { Either::Left(apply_inbound(output, NoiseConfig::ik_listener(server_dh))) } else { - Either::Right(apply_outbound(output, NoiseConfig::xx(server_dh), - upgrade::Version::V1)) + Either::Right(apply_outbound( + output, + NoiseConfig::xx(server_dh), + upgrade::Version::V1, + )) } }) .and_then(move |out, _| expect_identity(out, &client_id_public)); @@ -161,9 +211,11 @@ fn ik_xx() { let client_transport = TcpConfig::new() .and_then(move |output, endpoint| { if endpoint.is_dialer() { - Either::Left(apply_outbound(output, + Either::Left(apply_outbound( + output, NoiseConfig::ik_dialer(client_dh, server_id_public, server_dh_public), - upgrade::Version::V1)) + upgrade::Version::V1, + )) } else { Either::Right(apply_inbound(output, NoiseConfig::xx(client_dh))) } @@ -173,7 +225,9 @@ fn ik_xx() { run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec) -> bool) } type Output = (RemoteIdentity, NoiseOutput>>); @@ -188,14 +242,15 @@ where U::Dial: Send + 'static, U::Listener: Send + 'static, U::ListenerUpgrade: Send + 'static, - I: IntoIterator + Clone + I: IntoIterator + Clone, { futures::executor::block_on(async { let mut server: T::Listener = server_transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let server_address = server.try_next() + let server_address = server + .try_next() .await .expect("some event") .expect("no error") @@ -204,7 +259,8 @@ where let outbound_msgs = messages.clone(); let client_fut = async { - let mut client_session = client_transport.dial(server_address.clone()) + let mut client_session = client_transport + .dial(server_address.clone()) .unwrap() .await .map(|(_, session)| session) @@ -219,7 +275,8 @@ where }; let server_fut = async { - let mut server_session = server.try_next() + let mut server_session = server + .try_next() .await .expect("some event") .map(ListenerEvent::into_upgrade) @@ -236,12 +293,15 @@ where match server_session.read_exact(&mut n).await { Ok(()) => u64::from_be_bytes(n), Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => 0, - Err(e) => panic!("error reading len: {}", e) + Err(e) => panic!("error reading len: {}", e), } }; info!("server: reading message ({} bytes)", len); let mut server_buffer = vec![0; len.try_into().unwrap()]; - server_session.read_exact(&mut server_buffer).await.expect("no error"); + server_session + .read_exact(&mut server_buffer) + .await + .expect("no error"); assert_eq!(server_buffer, m.0) } }; @@ -250,12 +310,13 @@ where }) } -fn expect_identity(output: Output, pk: &identity::PublicKey) - -> impl Future, NoiseError>> -{ +fn expect_identity( + output: Output, + pk: &identity::PublicKey, +) -> impl Future, NoiseError>> { match output.0 { RemoteIdentity::IdentityKey(ref k) if k == pk => future::ok(output), - _ => panic!("Unexpected remote identity") + _ => panic!("Unexpected remote identity"), } } diff --git a/transports/plaintext/build.rs b/transports/plaintext/build.rs index 1b0feff6a40..56c7b20121a 100644 --- a/transports/plaintext/build.rs +++ b/transports/plaintext/build.rs @@ -19,6 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/structs.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/structs.proto"], &["src"]).unwrap(); } - diff --git a/transports/plaintext/src/error.rs b/transports/plaintext/src/error.rs index 7ede99af60c..9f512c4f58e 100644 --- a/transports/plaintext/src/error.rs +++ b/transports/plaintext/src/error.rs @@ -47,16 +47,14 @@ impl error::Error for PlainTextError { impl fmt::Display for PlainTextError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - PlainTextError::IoError(e) => - write!(f, "I/O error: {}", e), - PlainTextError::InvalidPayload(protobuf_error) => { - match protobuf_error { - Some(e) => write!(f, "Protobuf error: {}", e), - None => f.write_str("Failed to parse one of the handshake protobuf messages") - } + PlainTextError::IoError(e) => write!(f, "I/O error: {}", e), + PlainTextError::InvalidPayload(protobuf_error) => match protobuf_error { + Some(e) => write!(f, "Protobuf error: {}", e), + None => f.write_str("Failed to parse one of the handshake protobuf messages"), }, - PlainTextError::InvalidPeerId => - f.write_str("The peer id of the exchange isn't consistent with the remote public key"), + PlainTextError::InvalidPeerId => f.write_str( + "The peer id of the exchange isn't consistent with the remote public key", + ), } } } diff --git a/transports/plaintext/src/handshake.rs b/transports/plaintext/src/handshake.rs index d981df4d964..6534c6d7abd 100644 --- a/transports/plaintext/src/handshake.rs +++ b/transports/plaintext/src/handshake.rs @@ -18,14 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::PlainText2Config; use crate::error::PlainTextError; use crate::structs_proto::Exchange; +use crate::PlainText2Config; +use asynchronous_codec::{Framed, FramedParts}; use bytes::{Bytes, BytesMut}; use futures::prelude::*; -use asynchronous_codec::{Framed, FramedParts}; -use libp2p_core::{PublicKey, PeerId}; +use libp2p_core::{PeerId, PublicKey}; use log::{debug, trace}; use prost::Message; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; @@ -33,7 +33,7 @@ use unsigned_varint::codec::UviBytes; struct HandshakeContext { config: PlainText2Config, - state: T + state: T, } // HandshakeContext<()> --with_local-> HandshakeContext @@ -54,28 +54,31 @@ impl HandshakeContext { fn new(config: PlainText2Config) -> Self { let exchange = Exchange { id: Some(config.local_public_key.to_peer_id().to_bytes()), - pubkey: Some(config.local_public_key.to_protobuf_encoding()) + pubkey: Some(config.local_public_key.to_protobuf_encoding()), }; let mut buf = Vec::with_capacity(exchange.encoded_len()); - exchange.encode(&mut buf).expect("Vec provides capacity as needed"); + exchange + .encode(&mut buf) + .expect("Vec provides capacity as needed"); Self { config, state: Local { - exchange_bytes: buf - } + exchange_bytes: buf, + }, } } - fn with_remote(self, exchange_bytes: BytesMut) - -> Result, PlainTextError> - { + fn with_remote( + self, + exchange_bytes: BytesMut, + ) -> Result, PlainTextError> { let prop = match Exchange::decode(exchange_bytes) { Ok(prop) => prop, Err(e) => { debug!("failed to parse remote's exchange protobuf message"); return Err(PlainTextError::InvalidPayload(Some(e))); - }, + } }; let pb_pubkey = prop.pubkey.unwrap_or_default(); @@ -84,20 +87,20 @@ impl HandshakeContext { Err(_) => { debug!("failed to parse remote's exchange's pubkey protobuf"); return Err(PlainTextError::InvalidPayload(None)); - }, + } }; let peer_id = match PeerId::from_bytes(&prop.id.unwrap_or_default()) { Ok(p) => p, Err(_) => { debug!("failed to parse remote's exchange's id protobuf"); return Err(PlainTextError::InvalidPayload(None)); - }, + } }; // Check the validity of the remote's `Exchange`. if peer_id != public_key.to_peer_id() { debug!("the remote's `PeerId` isn't consistent with the remote's public key"); - return Err(PlainTextError::InvalidPeerId) + return Err(PlainTextError::InvalidPeerId); } Ok(HandshakeContext { @@ -105,13 +108,15 @@ impl HandshakeContext { state: Remote { peer_id, public_key, - } + }, }) } } -pub async fn handshake(socket: S, config: PlainText2Config) - -> Result<(S, Remote, Bytes), PlainTextError> +pub async fn handshake( + socket: S, + config: PlainText2Config, +) -> Result<(S, Remote, Bytes), PlainTextError> where S: AsyncRead + AsyncWrite + Send + Unpin, { @@ -122,7 +127,9 @@ where let context = HandshakeContext::new(config); trace!("sending exchange to remote"); - framed_socket.send(BytesMut::from(&context.state.exchange_bytes[..])).await?; + framed_socket + .send(BytesMut::from(&context.state.exchange_bytes[..])) + .await?; trace!("receiving the remote's exchange"); let context = match framed_socket.next().await { @@ -134,9 +141,17 @@ where } }; - trace!("received exchange from remote; pubkey = {:?}", context.state.public_key); - - let FramedParts { io, read_buffer, write_buffer, .. } = framed_socket.into_parts(); + trace!( + "received exchange from remote; pubkey = {:?}", + context.state.public_key + ); + + let FramedParts { + io, + read_buffer, + write_buffer, + .. + } = framed_socket.into_parts(); assert!(write_buffer.is_empty()); Ok((io, context.state, read_buffer.freeze())) } diff --git a/transports/plaintext/src/lib.rs b/transports/plaintext/src/lib.rs index 0f3a4c9e585..1e9cfecf66f 100644 --- a/transports/plaintext/src/lib.rs +++ b/transports/plaintext/src/lib.rs @@ -21,19 +21,16 @@ use crate::error::PlainTextError; use bytes::Bytes; +use futures::future::BoxFuture; use futures::future::{self, Ready}; use futures::prelude::*; -use futures::future::BoxFuture; -use libp2p_core::{ - identity, - InboundUpgrade, - OutboundUpgrade, - UpgradeInfo, - PeerId, - PublicKey, -}; +use libp2p_core::{identity, InboundUpgrade, OutboundUpgrade, PeerId, PublicKey, UpgradeInfo}; use log::debug; -use std::{io, iter, pin::Pin, task::{Context, Poll}}; +use std::{ + io, iter, + pin::Pin, + task::{Context, Poll}, +}; use void::Void; mod error; @@ -42,7 +39,6 @@ mod structs_proto { include!(concat!(env!("OUT_DIR"), "/structs.rs")); } - /// `PlainText1Config` is an insecure connection handshake for testing purposes only. /// /// > **Note**: Given that `PlainText1Config` has no notion of exchanging peer identity information it is not compatible @@ -119,7 +115,7 @@ impl UpgradeInfo for PlainText2Config { impl InboundUpgrade for PlainText2Config where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = (PeerId, PlainTextOutput); type Error = PlainTextError; @@ -132,7 +128,7 @@ where impl OutboundUpgrade for PlainText2Config where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = (PeerId, PlainTextOutput); type Error = PlainTextError; @@ -146,7 +142,7 @@ where impl PlainText2Config { async fn handshake(self, socket: T) -> Result<(PeerId, PlainTextOutput), PlainTextError> where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { debug!("Starting plaintext handshake."); let (socket, remote, read_buffer) = handshake::handshake(socket, self).await?; @@ -158,7 +154,7 @@ impl PlainText2Config { socket, remote_key: remote.public_key, read_buffer, - } + }, )) } } @@ -179,35 +175,35 @@ where } impl AsyncRead for PlainTextOutput { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) - -> Poll> - { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { if !self.read_buffer.is_empty() { let n = std::cmp::min(buf.len(), self.read_buffer.len()); let b = self.read_buffer.split_to(n); buf[..n].copy_from_slice(&b[..]); - return Poll::Ready(Ok(n)) + return Poll::Ready(Ok(n)); } AsyncRead::poll_read(Pin::new(&mut self.socket), cx, buf) } } impl AsyncWrite for PlainTextOutput { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) - -> Poll> - { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { AsyncWrite::poll_write(Pin::new(&mut self.socket), cx, buf) } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll> - { + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { AsyncWrite::poll_flush(Pin::new(&mut self.socket), cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll> - { + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { AsyncWrite::poll_close(Pin::new(&mut self.socket), cx) } } diff --git a/transports/plaintext/tests/smoke.rs b/transports/plaintext/tests/smoke.rs index 20a79c32e1d..ce155bdd92e 100644 --- a/transports/plaintext/tests/smoke.rs +++ b/transports/plaintext/tests/smoke.rs @@ -18,12 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::io::{AsyncWriteExt, AsyncReadExt}; +use futures::io::{AsyncReadExt, AsyncWriteExt}; use futures::stream::TryStreamExt; use libp2p_core::{ identity, multiaddr::Multiaddr, - transport::{Transport, ListenerEvent}, + transport::{ListenerEvent, Transport}, upgrade, }; use libp2p_plaintext::PlainText2Config; @@ -45,38 +45,40 @@ fn variable_msg_length() { let client_id_public = client_id.public(); futures::executor::block_on(async { - let server_transport = libp2p_core::transport::MemoryTransport{}.and_then( - move |output, endpoint| { + let server_transport = + libp2p_core::transport::MemoryTransport {}.and_then(move |output, endpoint| { upgrade::apply( output, - PlainText2Config{local_public_key: server_id_public}, + PlainText2Config { + local_public_key: server_id_public, + }, endpoint, libp2p_core::upgrade::Version::V1, ) - } - ); + }); - let client_transport = libp2p_core::transport::MemoryTransport{}.and_then( - move |output, endpoint| { + let client_transport = + libp2p_core::transport::MemoryTransport {}.and_then(move |output, endpoint| { upgrade::apply( output, - PlainText2Config{local_public_key: client_id_public}, + PlainText2Config { + local_public_key: client_id_public, + }, endpoint, libp2p_core::upgrade::Version::V1, ) - } - ); + }); - - let server_address: Multiaddr = format!( - "/memory/{}", - std::cmp::Ord::max(1, rand::random::()) - ).parse().unwrap(); + let server_address: Multiaddr = + format!("/memory/{}", std::cmp::Ord::max(1, rand::random::())) + .parse() + .unwrap(); let mut server = server_transport.listen_on(server_address.clone()).unwrap(); // Ignore server listen address event. - let _ = server.try_next() + let _ = server + .try_next() .await .expect("some event") .expect("no error") @@ -85,17 +87,25 @@ fn variable_msg_length() { let client_fut = async { debug!("dialing {:?}", server_address); - let (received_server_id, mut client_channel) = client_transport.dial(server_address).unwrap().await.unwrap(); + let (received_server_id, mut client_channel) = client_transport + .dial(server_address) + .unwrap() + .await + .unwrap(); assert_eq!(received_server_id, server_id.public().to_peer_id()); debug!("Client: writing message."); - client_channel.write_all(&mut msg_to_send).await.expect("no error"); + client_channel + .write_all(&mut msg_to_send) + .await + .expect("no error"); debug!("Client: flushing channel."); client_channel.flush().await.expect("no error"); }; let server_fut = async { - let mut server_channel = server.try_next() + let mut server_channel = server + .try_next() .await .expect("some event") .map(ListenerEvent::into_upgrade) @@ -108,7 +118,10 @@ fn variable_msg_length() { let mut server_buffer = vec![0; msg_to_receive.len()]; debug!("Server: reading message."); - server_channel.read_exact(&mut server_buffer).await.expect("reading client message"); + server_channel + .read_exact(&mut server_buffer) + .await + .expect("reading client message"); assert_eq!(server_buffer, msg_to_receive); }; @@ -117,5 +130,7 @@ fn variable_msg_length() { }) } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec)) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec)) } diff --git a/transports/pnet/src/crypt_writer.rs b/transports/pnet/src/crypt_writer.rs index a61957d395d..e13bb446ce6 100644 --- a/transports/pnet/src/crypt_writer.rs +++ b/transports/pnet/src/crypt_writer.rs @@ -100,7 +100,9 @@ fn poll_flush_buf( if written > 0 { buf.drain(..written); } - if let Poll::Ready(Ok(())) = ret { debug_assert!(buf.is_empty()); } + if let Poll::Ready(Ok(())) = ret { + debug_assert!(buf.is_empty()); + } ret } diff --git a/transports/pnet/src/lib.rs b/transports/pnet/src/lib.rs index 468e82cd7c9..efd27b14667 100644 --- a/transports/pnet/src/lib.rs +++ b/transports/pnet/src/lib.rs @@ -74,7 +74,10 @@ impl PreSharedKey { cipher.apply_keystream(&mut enc); let mut hasher = Shake128::default(); hasher.write_all(&enc).expect("shake128 failed"); - hasher.finalize_xof().read_exact(&mut out).expect("shake128 failed"); + hasher + .finalize_xof() + .read_exact(&mut out) + .expect("shake128 failed"); Fingerprint(out) } } diff --git a/transports/tcp/src/lib.rs b/transports/tcp/src/lib.rs index 5cf4f0fcebd..e556bf39087 100644 --- a/transports/tcp/src/lib.rs +++ b/transports/tcp/src/lib.rs @@ -57,14 +57,14 @@ use socket2::{Domain, Socket, Type}; use std::{ collections::HashSet, io, - net::{SocketAddr, IpAddr, TcpListener}, + net::{IpAddr, SocketAddr, TcpListener}, pin::Pin, sync::{Arc, RwLock}, task::{Context, Poll}, time::Duration, }; -use provider::{Provider, IfEvent}; +use provider::{IfEvent, Provider}; /// The configuration for a TCP/IP transport capability for libp2p. /// @@ -101,7 +101,7 @@ enum PortReuse { Enabled { /// The addresses and ports of the listening sockets /// registered as eligible for port reuse when dialing. - listen_addrs: Arc>> + listen_addrs: Arc>>, }, } @@ -151,7 +151,7 @@ impl PortReuse { if ip.is_ipv4() == remote_ip.is_ipv4() && ip.is_loopback() == remote_ip.is_loopback() { - return Some(SocketAddr::new(*ip, *port)) + return Some(SocketAddr::new(*ip, *port)); } } } @@ -302,7 +302,7 @@ where pub fn port_reuse(mut self, port_reuse: bool) -> Self { self.port_reuse = if port_reuse { PortReuse::Enabled { - listen_addrs: Arc::new(RwLock::new(HashSet::new())) + listen_addrs: Arc::new(RwLock::new(HashSet::new())), } } else { PortReuse::Disabled @@ -385,8 +385,7 @@ where return Err(TransportError::MultiaddrNotSupported(addr)); }; log::debug!("listening on {}", socket_addr); - self.do_listen(socket_addr) - .map_err(TransportError::Other) + self.do_listen(socket_addr).map_err(TransportError::Other) } fn dial(self, addr: Multiaddr) -> Result> { @@ -439,19 +438,19 @@ enum InAddr { /// The stream accepts connections on a single interface. One { addr: IpAddr, - out: Option + out: Option, }, /// The stream accepts connections on all interfaces. Any { addrs: HashSet, if_watch: IfWatch, - } + }, } /// A stream of incoming connections on one or more interfaces. pub struct TcpListenStream where - T: Provider + T: Provider, { /// The socket address that the listening socket is bound to, /// which may be a "wildcard address" like `INADDR_ANY` or `IN6ADDR_ANY` @@ -481,7 +480,7 @@ where impl TcpListenStream where - T: Provider + T: Provider, { /// Constructs a `TcpListenStream` for incoming connections around /// the given `TcpListener`. @@ -527,7 +526,7 @@ where match &self.in_addr { InAddr::One { addr, .. } => { self.port_reuse.unregister(*addr, self.listen_addr.port()); - }, + } InAddr::Any { addrs, .. } => { for addr in addrs { self.port_reuse.unregister(*addr, self.listen_addr.port()); @@ -539,7 +538,7 @@ where impl Drop for TcpListenStream where - T: Provider + T: Provider, { fn drop(&mut self) { self.disable_port_reuse(); @@ -565,7 +564,7 @@ where IfWatch::Pending(f) => match ready!(Pin::new(f).poll(cx)) { Ok(w) => { *if_watch = IfWatch::Ready(w); - continue + continue; } Err(err) => { log::debug! { @@ -578,42 +577,52 @@ where } }, // Consume all events for up/down interface changes. - IfWatch::Ready(watch) => while let Poll::Ready(ev) = T::poll_interfaces(watch, cx) { - match ev { - Ok(IfEvent::Up(inet)) => { - let ip = inet.addr(); - if me.listen_addr.is_ipv4() == ip.is_ipv4() && addrs.insert(ip) { - let ma = ip_to_multiaddr(ip, me.listen_addr.port()); - log::debug!("New listen address: {}", ma); - me.port_reuse.register(ip, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(ma)))); + IfWatch::Ready(watch) => { + while let Poll::Ready(ev) = T::poll_interfaces(watch, cx) { + match ev { + Ok(IfEvent::Up(inet)) => { + let ip = inet.addr(); + if me.listen_addr.is_ipv4() == ip.is_ipv4() && addrs.insert(ip) + { + let ma = ip_to_multiaddr(ip, me.listen_addr.port()); + log::debug!("New listen address: {}", ma); + me.port_reuse.register(ip, me.listen_addr.port()); + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress( + ma, + )))); + } } - } - Ok(IfEvent::Down(inet)) => { - let ip = inet.addr(); - if me.listen_addr.is_ipv4() == ip.is_ipv4() && addrs.remove(&ip) { - let ma = ip_to_multiaddr(ip, me.listen_addr.port()); - log::debug!("Expired listen address: {}", ma); - me.port_reuse.unregister(ip, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(ma)))); + Ok(IfEvent::Down(inet)) => { + let ip = inet.addr(); + if me.listen_addr.is_ipv4() == ip.is_ipv4() && addrs.remove(&ip) + { + let ma = ip_to_multiaddr(ip, me.listen_addr.port()); + log::debug!("Expired listen address: {}", ma); + me.port_reuse.unregister(ip, me.listen_addr.port()); + return Poll::Ready(Some(Ok( + ListenerEvent::AddressExpired(ma), + ))); + } + } + Err(err) => { + log::debug! { + "Failure polling interfaces: {:?}. Scheduling retry.", + err + }; + me.pause = Some(Delay::new(me.sleep_on_error)); + return Poll::Ready(Some(Ok(ListenerEvent::Error(err)))); } - } - Err(err) => { - log::debug! { - "Failure polling interfaces: {:?}. Scheduling retry.", - err - }; - me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(ListenerEvent::Error(err)))); } } - }, + } }, // If the listener is bound to a single interface, make sure the // address is registered for port reuse and reported once. - InAddr::One { addr, out } => if let Some(multiaddr) = out.take() { - me.port_reuse.register(*addr, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(multiaddr)))) + InAddr::One { addr, out } => { + if let Some(multiaddr) = out.take() { + me.port_reuse.register(*addr, me.listen_addr.port()); + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(multiaddr)))); + } } } @@ -640,7 +649,8 @@ where }; let local_addr = ip_to_multiaddr(incoming.local_addr.ip(), incoming.local_addr.port()); - let remote_addr = ip_to_multiaddr(incoming.remote_addr.ip(), incoming.remote_addr.port()); + let remote_addr = + ip_to_multiaddr(incoming.remote_addr.ip(), incoming.remote_addr.port()); log::debug!("Incoming connection from {} at {}", remote_addr, local_addr); @@ -666,18 +676,18 @@ fn multiaddr_to_socketaddr(mut addr: Multiaddr) -> Result { match proto { Protocol::Ip4(ipv4) => match port { Some(port) => return Ok(SocketAddr::new(ipv4.into(), port)), - None => return Err(()) + None => return Err(()), }, Protocol::Ip6(ipv6) => match port { Some(port) => return Ok(SocketAddr::new(ipv6.into(), port)), - None => return Err(()) + None => return Err(()), }, Protocol::Tcp(portnum) => match port { Some(_) => return Err(()), - None => { port = Some(portnum) } - } + None => port = Some(portnum), + }, Protocol::P2p(_) => {} - _ => return Err(()) + _ => return Err(()), } } Err(()) @@ -685,15 +695,13 @@ fn multiaddr_to_socketaddr(mut addr: Multiaddr) -> Result { // Create a [`Multiaddr`] from the given IP address and port number. fn ip_to_multiaddr(ip: IpAddr, port: u16) -> Multiaddr { - Multiaddr::empty() - .with(ip.into()) - .with(Protocol::Tcp(port)) + Multiaddr::empty().with(ip.into()).with(Protocol::Tcp(port)) } #[cfg(test)] mod tests { - use futures::channel::mpsc; use super::*; + use futures::channel::mpsc; #[test] fn multiaddr_to_tcp_conversion() { @@ -748,7 +756,7 @@ mod tests { fn communicating_between_dialer_and_listener() { env_logger::try_init().ok(); - async fn listener(addr: Multiaddr, mut ready_tx: mpsc::Sender) { + async fn listener(addr: Multiaddr, mut ready_tx: mpsc::Sender) { let tcp = GenTcpConfig::::new(); let mut listener = tcp.listen_on(addr).unwrap(); loop { @@ -762,7 +770,7 @@ mod tests { upgrade.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, [1, 2, 3]); upgrade.write_all(&[4, 5, 6]).await.unwrap(); - return + return; } e => panic!("Unexpected listener event: {:?}", e), } @@ -798,7 +806,10 @@ mod tests { let (ready_tx, ready_rx) = mpsc::channel(1); let listener = listener::(addr.clone(), ready_tx); let dialer = dialer::(ready_rx); - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); let tasks = tokio_crate::task::LocalSet::new(); let listener = tasks.spawn_local(listener); tasks.block_on(&rt, dialer); @@ -833,7 +844,7 @@ mod tests { panic!("No TCP port in address: {}", a) } ready_tx.send(a).await.ok(); - return + return; } _ => {} } @@ -862,7 +873,10 @@ mod tests { let (ready_tx, ready_rx) = mpsc::channel(1); let listener = listener::(addr.clone(), ready_tx); let dialer = dialer::(ready_rx); - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); let tasks = tokio_crate::task::LocalSet::new(); let listener = tasks.spawn_local(listener); tasks.block_on(&rt, dialer); @@ -892,7 +906,7 @@ mod tests { upgrade.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, [1, 2, 3]); upgrade.write_all(&[4, 5, 6]).await.unwrap(); - return + return; } e => panic!("Unexpected event: {:?}", e), } @@ -913,7 +927,7 @@ mod tests { socket.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, [4, 5, 6]); } - e => panic!("Unexpected listener event: {:?}", e) + e => panic!("Unexpected listener event: {:?}", e), } } @@ -933,7 +947,10 @@ mod tests { let (ready_tx, ready_rx) = mpsc::channel(1); let listener = listener::(addr.clone(), ready_tx); let dialer = dialer::(addr.clone(), ready_rx); - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); let tasks = tokio_crate::task::LocalSet::new(); let listener = tasks.spawn_local(listener); tasks.block_on(&rt, dialer); @@ -959,7 +976,7 @@ mod tests { match listener2.next().await.unwrap().unwrap() { ListenerEvent::NewAddress(addr2) => { assert_eq!(addr1, addr2); - return + return; } e => panic!("Unexpected listener event: {:?}", e), } @@ -978,7 +995,10 @@ mod tests { #[cfg(feature = "tokio")] { let listener = listen_twice::(addr.clone()); - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); rt.block_on(listener); } } @@ -1011,7 +1031,10 @@ mod tests { #[cfg(feature = "tokio")] { - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); let new_addr = rt.block_on(listen::(addr.clone())); assert!(!new_addr.to_string().contains("tcp/0")); } diff --git a/transports/tcp/src/provider.rs b/transports/tcp/src/provider.rs index 091a6691087..7ebeaa49ee8 100644 --- a/transports/tcp/src/provider.rs +++ b/transports/tcp/src/provider.rs @@ -26,12 +26,12 @@ pub mod async_io; #[cfg(feature = "tokio")] pub mod tokio; -use futures::io::{AsyncRead, AsyncWrite}; use futures::future::BoxFuture; +use futures::io::{AsyncRead, AsyncWrite}; use ipnet::IpNet; +use std::net::{SocketAddr, TcpListener, TcpStream}; use std::task::{Context, Poll}; use std::{fmt, io}; -use std::net::{SocketAddr, TcpListener, TcpStream}; /// An event relating to a change of availability of an address /// on a network interface. @@ -73,7 +73,10 @@ pub trait Provider: Clone + Send + 'static { /// Polls a [`Self::Listener`] for an incoming connection, ensuring a task wakeup, /// if necessary. - fn poll_accept(_: &mut Self::Listener, _: &mut Context<'_>) -> Poll>>; + fn poll_accept( + _: &mut Self::Listener, + _: &mut Context<'_>, + ) -> Poll>>; /// Polls a [`Self::IfWatcher`] for network interface changes, ensuring a task wakeup, /// if necessary. diff --git a/transports/tcp/src/provider/async_io.rs b/transports/tcp/src/provider/async_io.rs index b4ce74d6901..ab65544d872 100644 --- a/transports/tcp/src/provider/async_io.rs +++ b/transports/tcp/src/provider/async_io.rs @@ -18,15 +18,13 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use super::{Provider, IfEvent, Incoming}; +use super::{IfEvent, Incoming, Provider}; use async_io_crate::Async; -use futures::{ - future::{BoxFuture, FutureExt}, -}; +use futures::future::{BoxFuture, FutureExt}; use std::io; -use std::task::{Poll, Context}; use std::net; +use std::task::{Context, Poll}; #[derive(Copy, Clone)] pub enum Tcp {} @@ -49,10 +47,14 @@ impl Provider for Tcp { let stream = Async::new(s)?; stream.writable().await?; Ok(stream) - }.boxed() + } + .boxed() } - fn poll_accept(l: &mut Self::Listener, cx: &mut Context<'_>) -> Poll>> { + fn poll_accept( + l: &mut Self::Listener, + cx: &mut Context<'_>, + ) -> Poll>> { let (stream, remote_addr) = loop { match l.poll_readable(cx) { Poll::Pending => return Poll::Pending, @@ -64,13 +66,17 @@ impl Provider for Tcp { // Since it doesn't do any harm, account for false positives of // `poll_readable` just in case, i.e. try again. } - } + }, } }; let local_addr = stream.get_ref().local_addr()?; - Poll::Ready(Ok(Incoming { stream, local_addr, remote_addr })) + Poll::Ready(Ok(Incoming { + stream, + local_addr, + remote_addr, + })) } fn poll_interfaces(w: &mut Self::IfWatcher, cx: &mut Context<'_>) -> Poll> { diff --git a/transports/tcp/src/provider/tokio.rs b/transports/tcp/src/provider/tokio.rs index 0e8136f2c60..257bccd2926 100644 --- a/transports/tcp/src/provider/tokio.rs +++ b/transports/tcp/src/provider/tokio.rs @@ -18,22 +18,22 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use super::{Provider, IfEvent, Incoming}; +use super::{IfEvent, Incoming, Provider}; use futures::{ future::{self, BoxFuture, FutureExt}, prelude::*, }; use futures_timer::Delay; -use if_addrs::{IfAddr, get_if_addrs}; +use if_addrs::{get_if_addrs, IfAddr}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use std::collections::HashSet; use std::convert::TryFrom; use std::io; -use std::task::{Poll, Context}; -use std::time::Duration; use std::net; use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; #[derive(Copy, Clone)] pub enum Tcp {} @@ -50,13 +50,12 @@ impl Provider for Tcp { type IfWatcher = IfWatcher; fn if_watcher() -> BoxFuture<'static, io::Result> { - future::ready(Ok( - IfWatcher { - addrs: HashSet::new(), - delay: Delay::new(Duration::from_secs(0)), - pending: Vec::new(), - } - )).boxed() + future::ready(Ok(IfWatcher { + addrs: HashSet::new(), + delay: Delay::new(Duration::from_secs(0)), + pending: Vec::new(), + })) + .boxed() } fn new_listener(l: net::TcpListener) -> io::Result { @@ -68,48 +67,59 @@ impl Provider for Tcp { let stream = tokio_crate::net::TcpStream::try_from(s)?; stream.writable().await?; Ok(TcpStream(stream)) - }.boxed() + } + .boxed() } - fn poll_accept(l: &mut Self::Listener, cx: &mut Context<'_>) - -> Poll>> - { + fn poll_accept( + l: &mut Self::Listener, + cx: &mut Context<'_>, + ) -> Poll>> { let (stream, remote_addr) = match l.poll_accept(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok((stream, remote_addr))) => (stream, remote_addr) + Poll::Ready(Ok((stream, remote_addr))) => (stream, remote_addr), }; let local_addr = stream.local_addr()?; let stream = TcpStream(stream); - Poll::Ready(Ok(Incoming { stream, local_addr, remote_addr })) + Poll::Ready(Ok(Incoming { + stream, + local_addr, + remote_addr, + })) } fn poll_interfaces(w: &mut Self::IfWatcher, cx: &mut Context<'_>) -> Poll> { loop { if let Some(event) = w.pending.pop() { - return Poll::Ready(Ok(event)) + return Poll::Ready(Ok(event)); } match Pin::new(&mut w.delay).poll(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(()) => { let ifs = get_if_addrs()?; - let addrs = ifs.into_iter().map(|iface| match iface.addr { - IfAddr::V4(ip4) => { - let prefix_len = (!u32::from_be_bytes(ip4.netmask.octets())).leading_zeros(); - let ipnet = Ipv4Net::new(ip4.ip, prefix_len as u8) - .expect("prefix_len can not exceed 32"); - IpNet::V4(ipnet) - } - IfAddr::V6(ip6) => { - let prefix_len = (!u128::from_be_bytes(ip6.netmask.octets())).leading_zeros(); - let ipnet = Ipv6Net::new(ip6.ip, prefix_len as u8) - .expect("prefix_len can not exceed 128"); - IpNet::V6(ipnet) - } - }).collect::>(); + let addrs = ifs + .into_iter() + .map(|iface| match iface.addr { + IfAddr::V4(ip4) => { + let prefix_len = + (!u32::from_be_bytes(ip4.netmask.octets())).leading_zeros(); + let ipnet = Ipv4Net::new(ip4.ip, prefix_len as u8) + .expect("prefix_len can not exceed 32"); + IpNet::V4(ipnet) + } + IfAddr::V6(ip6) => { + let prefix_len = + (!u128::from_be_bytes(ip6.netmask.octets())).leading_zeros(); + let ipnet = Ipv6Net::new(ip6.ip, prefix_len as u8) + .expect("prefix_len can not exceed 128"); + IpNet::V6(ipnet) + } + }) + .collect::>(); for down in w.addrs.difference(&addrs) { w.pending.push(IfEvent::Down(*down)); @@ -138,15 +148,27 @@ impl Into for TcpStream { } impl AsyncRead for TcpStream { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { let mut read_buf = tokio_crate::io::ReadBuf::new(buf); - futures::ready!(tokio_crate::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, &mut read_buf))?; + futures::ready!(tokio_crate::io::AsyncRead::poll_read( + Pin::new(&mut self.0), + cx, + &mut read_buf + ))?; Poll::Ready(Ok(read_buf.filled().len())) } } impl AsyncWrite for TcpStream { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { tokio_crate::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf) } @@ -161,7 +183,7 @@ impl AsyncWrite for TcpStream { fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>] + bufs: &[io::IoSlice<'_>], ) -> Poll> { tokio_crate::io::AsyncWrite::poll_write_vectored(Pin::new(&mut self.0), cx, bufs) } diff --git a/transports/uds/src/lib.rs b/transports/uds/src/lib.rs index 67da6c5fd85..34ac4eb51c3 100644 --- a/transports/uds/src/lib.rs +++ b/transports/uds/src/lib.rs @@ -34,12 +34,15 @@ #![cfg(all(unix, not(target_os = "emscripten")))] #![cfg_attr(docsrs, doc(cfg(all(unix, not(target_os = "emscripten")))))] -use futures::{prelude::*, future::{BoxFuture, Ready}}; use futures::stream::BoxStream; +use futures::{ + future::{BoxFuture, Ready}, + prelude::*, +}; use libp2p_core::{ + multiaddr::{Multiaddr, Protocol}, + transport::{ListenerEvent, TransportError}, Transport, - multiaddr::{Protocol, Multiaddr}, - transport::{ListenerEvent, TransportError} }; use log::debug; use std::{io, path::PathBuf}; @@ -145,14 +148,14 @@ fn multiaddr_to_path(addr: &Multiaddr) -> Result { Some(Protocol::Unix(ref path)) => { let path = PathBuf::from(path.as_ref()); if !path.is_absolute() { - return Err(()) + return Err(()); } match protocols.next() { None | Some(Protocol::P2p(_)) => Ok(path), - Some(_) => Err(()) + Some(_) => Err(()), } } - _ => Err(()) + _ => Err(()), } } @@ -160,15 +163,17 @@ fn multiaddr_to_path(addr: &Multiaddr) -> Result { mod tests { use super::{multiaddr_to_path, UdsConfig}; use futures::{channel::oneshot, prelude::*}; + use libp2p_core::{ + multiaddr::{Multiaddr, Protocol}, + Transport, + }; use std::{self, borrow::Cow, path::Path}; - use libp2p_core::{Transport, multiaddr::{Protocol, Multiaddr}}; use tempfile; #[test] fn multiaddr_to_path_conversion() { assert!( - multiaddr_to_path(&"/ip4/127.0.0.1/udp/1234".parse::().unwrap()) - .is_err() + multiaddr_to_path(&"/ip4/127.0.0.1/udp/1234".parse::().unwrap()).is_err() ); assert_eq!( @@ -185,21 +190,27 @@ mod tests { fn communicating_between_dialer_and_listener() { let temp_dir = tempfile::tempdir().unwrap(); let socket = temp_dir.path().join("socket"); - let addr = Multiaddr::from(Protocol::Unix(Cow::Owned(socket.to_string_lossy().into_owned()))); + let addr = Multiaddr::from(Protocol::Unix(Cow::Owned( + socket.to_string_lossy().into_owned(), + ))); let (tx, rx) = oneshot::channel(); async_std::task::spawn(async move { let mut listener = UdsConfig::new().listen_on(addr).unwrap(); - let listen_addr = listener.try_next().await.unwrap() + let listen_addr = listener + .try_next() + .await + .unwrap() .expect("some event") .into_new_address() .expect("listen address"); tx.send(listen_addr).unwrap(); - let (sock, _addr) = listener.try_filter_map(|e| future::ok(e.into_upgrade())) + let (sock, _addr) = listener + .try_filter_map(|e| future::ok(e.into_upgrade())) .try_next() .await .unwrap() @@ -220,18 +231,16 @@ mod tests { } #[test] - #[ignore] // TODO: for the moment unix addresses fail to parse + #[ignore] // TODO: for the moment unix addresses fail to parse fn larger_addr_denied() { let uds = UdsConfig::new(); - let addr = "/unix//foo/bar" - .parse::() - .unwrap(); + let addr = "/unix//foo/bar".parse::().unwrap(); assert!(uds.listen_on(addr).is_err()); } #[test] - #[ignore] // TODO: for the moment unix addresses fail to parse + #[ignore] // TODO: for the moment unix addresses fail to parse fn relative_addr_denied() { assert!("/unix/./foo/bar".parse::().is_err()); } diff --git a/transports/wasm-ext/src/lib.rs b/transports/wasm-ext/src/lib.rs index cec2ad1c1b9..27aafdb70c3 100644 --- a/transports/wasm-ext/src/lib.rs +++ b/transports/wasm-ext/src/lib.rs @@ -32,11 +32,11 @@ //! module. //! -use futures::{prelude::*, future::Ready}; +use futures::{future::Ready, prelude::*}; use libp2p_core::{transport::ListenerEvent, transport::TransportError, Multiaddr, Transport}; use parity_send_wrapper::SendWrapper; use std::{collections::VecDeque, error, fmt, io, mem, pin::Pin, task::Context, task::Poll}; -use wasm_bindgen::{JsCast, prelude::*}; +use wasm_bindgen::{prelude::*, JsCast}; use wasm_bindgen_futures::JsFuture; /// Contains the definition that one must match on the JavaScript side. @@ -172,16 +172,13 @@ impl Transport for ExtTransport { type Dial = Dial; fn listen_on(self, addr: Multiaddr) -> Result> { - let iter = self - .inner - .listen_on(&addr.to_string()) - .map_err(|err| { - if is_not_supported_error(&err) { - TransportError::MultiaddrNotSupported(addr) - } else { - TransportError::Other(JsErr::from(err)) - } - })?; + let iter = self.inner.listen_on(&addr.to_string()).map_err(|err| { + if is_not_supported_error(&err) { + TransportError::MultiaddrNotSupported(addr) + } else { + TransportError::Other(JsErr::from(err)) + } + })?; Ok(Listen { iterator: SendWrapper::new(iter), @@ -191,16 +188,13 @@ impl Transport for ExtTransport { } fn dial(self, addr: Multiaddr) -> Result> { - let promise = self - .inner - .dial(&addr.to_string()) - .map_err(|err| { - if is_not_supported_error(&err) { - TransportError::MultiaddrNotSupported(addr) - } else { - TransportError::Other(JsErr::from(err)) - } - })?; + let promise = self.inner.dial(&addr.to_string()).map_err(|err| { + if is_not_supported_error(&err) { + TransportError::MultiaddrNotSupported(addr) + } else { + TransportError::Other(JsErr::from(err)) + } + })?; Ok(Dial { inner: SendWrapper::new(promise.into()), @@ -315,7 +309,9 @@ impl Stream for Listen { .flat_map(|e| e.to_vec().into_iter()) { match js_value_to_addr(&addr) { - Ok(addr) => self.pending_events.push_back(ListenerEvent::NewAddress(addr)), + Ok(addr) => self + .pending_events + .push_back(ListenerEvent::NewAddress(addr)), Err(err) => self.pending_events.push_back(ListenerEvent::Error(err)), } } @@ -375,10 +371,16 @@ impl fmt::Debug for Connection { } impl AsyncRead for Connection { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { loop { match mem::replace(&mut self.read_state, ConnectionReadState::Finished) { - ConnectionReadState::Finished => break Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())), + ConnectionReadState::Finished => { + break Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } ConnectionReadState::PendingData(ref data) if data.is_empty() => { let iter_next = self.read_iterator.next().map_err(JsErr::from)?; @@ -411,7 +413,9 @@ impl AsyncRead for Connection { let data = match Future::poll(Pin::new(&mut *promise), cx) { Poll::Ready(Ok(ref data)) if data.is_null() => break Poll::Ready(Ok(0)), Poll::Ready(Ok(data)) => data, - Poll::Ready(Err(err)) => break Poll::Ready(Err(io::Error::from(JsErr::from(err)))), + Poll::Ready(Err(err)) => { + break Poll::Ready(Err(io::Error::from(JsErr::from(err)))) + } Poll::Pending => { self.read_state = ConnectionReadState::Waiting(promise); break Poll::Pending; @@ -439,14 +443,20 @@ impl AsyncRead for Connection { } impl AsyncWrite for Connection { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { // Note: as explained in the doc-comments of `Connection`, each call to this function must // map to exactly one call to `self.inner.write()`. if let Some(mut promise) = self.previous_write_promise.take() { match Future::poll(Pin::new(&mut *promise), cx) { Poll::Ready(Ok(_)) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err(io::Error::from(JsErr::from(err)))), + Poll::Ready(Err(err)) => { + return Poll::Ready(Err(io::Error::from(JsErr::from(err)))) + } Poll::Pending => { self.previous_write_promise = Some(promise); return Poll::Pending; diff --git a/transports/websocket/src/error.rs b/transports/websocket/src/error.rs index 65a5d8350c0..47421d4c069 100644 --- a/transports/websocket/src/error.rs +++ b/transports/websocket/src/error.rs @@ -18,8 +18,8 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use libp2p_core::Multiaddr; use crate::tls; +use libp2p_core::Multiaddr; use std::{error, fmt}; /// Error in WebSockets. @@ -38,7 +38,7 @@ pub enum Error { /// The location header URL was invalid. InvalidRedirectLocation, /// Websocket base framing error. - Base(Box) + Base(Box), } impl fmt::Display for Error { @@ -50,7 +50,7 @@ impl fmt::Display for Error { Error::InvalidMultiaddr(ma) => write!(f, "invalid multi-address: {}", ma), Error::TooManyRedirects => f.write_str("too many redirects"), Error::InvalidRedirectLocation => f.write_str("invalid redirect location"), - Error::Base(err) => write!(f, "{}", err) + Error::Base(err) => write!(f, "{}", err), } } } @@ -64,7 +64,7 @@ impl error::Error for Error { Error::Base(err) => Some(&**err), Error::InvalidMultiaddr(_) | Error::TooManyRedirects - | Error::InvalidRedirectLocation => None + | Error::InvalidRedirectLocation => None, } } } diff --git a/transports/websocket/src/framed.rs b/transports/websocket/src/framed.rs index 204eddd836f..dc57cb8e220 100644 --- a/transports/websocket/src/framed.rs +++ b/transports/websocket/src/framed.rs @@ -18,15 +18,15 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures_rustls::{webpki, client, server}; use crate::{error::Error, tls}; use either::Either; use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream}; +use futures_rustls::{client, server, webpki}; use libp2p_core::{ - Transport, either::EitherOutput, - multiaddr::{Protocol, Multiaddr}, - transport::{ListenerEvent, TransportError} + multiaddr::{Multiaddr, Protocol}, + transport::{ListenerEvent, TransportError}, + Transport, }; use log::{debug, trace}; use soketto::{connection, extension::deflate::Deflate, handshake}; @@ -45,7 +45,7 @@ pub struct WsConfig { max_data_size: usize, tls_config: tls::Config, max_redirects: u8, - use_deflate: bool + use_deflate: bool, } impl WsConfig { @@ -56,7 +56,7 @@ impl WsConfig { max_data_size: MAX_DATA_SIZE, tls_config: tls::Config::client(), max_redirects: 0, - use_deflate: false + use_deflate: false, } } @@ -104,11 +104,12 @@ where T::Dial: Send + 'static, T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, - T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static + T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = Connection; type Error = Error; - type Listener = BoxStream<'static, Result, Self::Error>>; + type Listener = + BoxStream<'static, Result, Self::Error>>; type ListenerUpgrade = BoxFuture<'static, Result>; type Dial = BoxFuture<'static, Result>; @@ -116,24 +117,28 @@ where let mut inner_addr = addr.clone(); let (use_tls, proto) = match inner_addr.pop() { - Some(p@Protocol::Wss(_)) => + Some(p @ Protocol::Wss(_)) => { if self.tls_config.server.is_some() { (true, p) } else { debug!("/wss address but TLS server support is not configured"); - return Err(TransportError::MultiaddrNotSupported(addr)) + return Err(TransportError::MultiaddrNotSupported(addr)); } - Some(p@Protocol::Ws(_)) => (false, p), + } + Some(p @ Protocol::Ws(_)) => (false, p), _ => { debug!("{} is not a websocket multiaddr", addr); - return Err(TransportError::MultiaddrNotSupported(addr)) + return Err(TransportError::MultiaddrNotSupported(addr)); } }; let tls_config = self.tls_config; let max_size = self.max_data_size; let use_deflate = self.use_deflate; - let transport = self.transport.listen_on(inner_addr).map_err(|e| e.map(Error::Transport))?; + let transport = self + .transport + .listen_on(inner_addr) + .map_err(|e| e.map(Error::Transport))?; let listen = transport .map_err(Error::Transport) .map_ok(move |event| match event { @@ -146,10 +151,12 @@ where a = a.with(proto.clone()); ListenerEvent::AddressExpired(a) } - ListenerEvent::Error(err) => { - ListenerEvent::Error(Error::Transport(err)) - } - ListenerEvent::Upgrade { upgrade, mut local_addr, mut remote_addr } => { + ListenerEvent::Error(err) => ListenerEvent::Error(Error::Transport(err)), + ListenerEvent::Upgrade { + upgrade, + mut local_addr, + mut remote_addr, + } => { local_addr = local_addr.with(proto.clone()); remote_addr = remote_addr.with(proto.clone()); let remote1 = remote_addr.clone(); // used for logging @@ -160,28 +167,30 @@ where let stream = upgrade.map_err(Error::Transport).await?; trace!("incoming connection from {}", remote1); - let stream = - if use_tls { // begin TLS session - let server = tls_config - .server - .expect("for use_tls we checked server is not none"); + let stream = if use_tls { + // begin TLS session + let server = tls_config + .server + .expect("for use_tls we checked server is not none"); - trace!("awaiting TLS handshake with {}", remote1); + trace!("awaiting TLS handshake with {}", remote1); - let stream = server.accept(stream) - .map_err(move |e| { - debug!("TLS handshake with {} failed: {}", remote1, e); - Error::Tls(tls::Error::from(e)) - }) - .await?; + let stream = server + .accept(stream) + .map_err(move |e| { + debug!("TLS handshake with {} failed: {}", remote1, e); + Error::Tls(tls::Error::from(e)) + }) + .await?; - let stream: TlsOrPlain<_> = - EitherOutput::First(EitherOutput::Second(stream)); + let stream: TlsOrPlain<_> = + EitherOutput::First(EitherOutput::Second(stream)); - stream - } else { // continue with plain stream - EitherOutput::Second(stream) - }; + stream + } else { + // continue with plain stream + EitherOutput::Second(stream) + }; trace!("receiving websocket handshake request from {}", remote2); @@ -192,7 +201,8 @@ where } let ws_key = { - let request = server.receive_request() + let request = server + .receive_request() .map_err(|e| Error::Handshake(Box::new(e))) .await?; request.into_key() @@ -200,13 +210,13 @@ where trace!("accepting websocket handshake request from {}", remote2); - let response = - handshake::server::Response::Accept { - key: &ws_key, - protocol: None - }; + let response = handshake::server::Response::Accept { + key: &ws_key, + protocol: None, + }; - server.send_response(&response) + server + .send_response(&response) .map_err(|e| Error::Handshake(Box::new(e))) .await?; @@ -223,7 +233,7 @@ where ListenerEvent::Upgrade { upgrade: Box::pin(upgrade) as BoxFuture<'static, _>, local_addr, - remote_addr + remote_addr, } } }); @@ -233,7 +243,9 @@ where fn dial(self, addr: Multiaddr) -> Result> { let addr = match parse_ws_dial_addr(addr) { Ok(addr) => addr, - Err(Error::InvalidMultiaddr(a)) => return Err(TransportError::MultiaddrNotSupported(a)), + Err(Error::InvalidMultiaddr(a)) => { + return Err(TransportError::MultiaddrNotSupported(a)) + } Err(e) => return Err(TransportError::Other(e)), }; @@ -247,13 +259,13 @@ where Ok(Either::Left(redirect)) => { if remaining_redirects == 0 { debug!("Too many redirects (> {})", self.max_redirects); - return Err(Error::TooManyRedirects) + return Err(Error::TooManyRedirects); } remaining_redirects -= 1; addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)? } Ok(Either::Right(conn)) => return Ok(conn), - Err(e) => return Err(e) + Err(e) => return Err(e), } } }; @@ -269,37 +281,45 @@ where impl WsConfig where T: Transport, - T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static + T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static, { /// Attempts to dial the given address and perform a websocket handshake. - async fn dial_once(self, addr: WsAddress) -> Result>, Error> { + async fn dial_once( + self, + addr: WsAddress, + ) -> Result>, Error> { trace!("Dialing websocket address: {:?}", addr); - let dial = self.transport.dial(addr.tcp_addr) - .map_err(|e| match e { - TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a), - TransportError::Other(e) => Error::Transport(e) - })?; + let dial = self.transport.dial(addr.tcp_addr).map_err(|e| match e { + TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a), + TransportError::Other(e) => Error::Transport(e), + })?; let stream = dial.map_err(Error::Transport).await?; trace!("TCP connection to {} established.", addr.host_port); - let stream = - if addr.use_tls { // begin TLS session - let dns_name = addr.dns_name.expect("for use_tls we have checked that dns_name is some"); - trace!("Starting TLS handshake with {:?}", dns_name); - let stream = self.tls_config.client.connect(dns_name.as_ref(), stream) - .map_err(|e| { - debug!("TLS handshake with {:?} failed: {}", dns_name, e); - Error::Tls(tls::Error::from(e)) - }) - .await?; - - let stream: TlsOrPlain<_> = EitherOutput::First(EitherOutput::First(stream)); - stream - } else { // continue with plain stream - EitherOutput::Second(stream) - }; + let stream = if addr.use_tls { + // begin TLS session + let dns_name = addr + .dns_name + .expect("for use_tls we have checked that dns_name is some"); + trace!("Starting TLS handshake with {:?}", dns_name); + let stream = self + .tls_config + .client + .connect(dns_name.as_ref(), stream) + .map_err(|e| { + debug!("TLS handshake with {:?} failed: {}", dns_name, e); + Error::Tls(tls::Error::from(e)) + }) + .await?; + + let stream: TlsOrPlain<_> = EitherOutput::First(EitherOutput::First(stream)); + stream + } else { + // continue with plain stream + EitherOutput::Second(stream) + }; trace!("Sending websocket handshake to {}", addr.host_port); @@ -309,9 +329,19 @@ where client.add_extension(Box::new(Deflate::new(connection::Mode::Client))); } - match client.handshake().map_err(|e| Error::Handshake(Box::new(e))).await? { - handshake::ServerResponse::Redirect { status_code, location } => { - debug!("received redirect ({}); location: {}", status_code, location); + match client + .handshake() + .map_err(|e| Error::Handshake(Box::new(e))) + .await? + { + handshake::ServerResponse::Redirect { + status_code, + location, + } => { + debug!( + "received redirect ({}); location: {}", + status_code, location + ); Ok(Either::Left(location)) } handshake::ServerResponse::Rejected { status_code } => { @@ -349,20 +379,26 @@ fn parse_ws_dial_addr(addr: Multiaddr) -> Result> { let mut tcp = protocols.next(); let (host_port, dns_name) = loop { match (ip, tcp) { - (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) - => break (format!("{}:{}", ip, port), None), - (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) - => break (format!("{}:{}", ip, port), None), - (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port))) | - (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) | - (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) | - (Some(Protocol::Dnsaddr(h)), Some(Protocol::Tcp(port))) - => break (format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned())), + (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => { + break (format!("{}:{}", ip, port), None) + } + (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => { + break (format!("{}:{}", ip, port), None) + } + (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port))) + | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) + | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) + | (Some(Protocol::Dnsaddr(h)), Some(Protocol::Tcp(port))) => { + break ( + format!("{}:{}", &h, port), + Some(tls::dns_name_ref(&h)?.to_owned()), + ) + } (Some(_), Some(p)) => { ip = Some(p); tcp = protocols.next(); } - _ => return Err(Error::InvalidMultiaddr(addr)) + _ => return Err(Error::InvalidMultiaddr(addr)), } }; @@ -373,16 +409,16 @@ fn parse_ws_dial_addr(addr: Multiaddr) -> Result> { let mut p2p = None; let (use_tls, path) = loop { match protocols.pop() { - p@Some(Protocol::P2p(_)) => { p2p = p } + p @ Some(Protocol::P2p(_)) => p2p = p, Some(Protocol::Ws(path)) => break (false, path.into_owned()), Some(Protocol::Wss(path)) => { if dns_name.is_none() { debug!("Missing DNS name in WSS address: {}", addr); - return Err(Error::InvalidMultiaddr(addr)) + return Err(Error::InvalidMultiaddr(addr)); } - break (true, path.into_owned()) + break (true, path.into_owned()); } - _ => return Err(Error::InvalidMultiaddr(addr)) + _ => return Err(Error::InvalidMultiaddr(addr)), } }; @@ -390,7 +426,7 @@ fn parse_ws_dial_addr(addr: Multiaddr) -> Result> { // makes up the the address for the inner TCP-based transport. let tcp_addr = match p2p { Some(p) => protocols.with(p), - None => protocols + None => protocols, }; Ok(WsAddress { @@ -408,16 +444,10 @@ fn location_to_multiaddr(location: &str) -> Result> { Ok(url) => { let mut a = Multiaddr::empty(); match url.host() { - Some(url::Host::Domain(h)) => { - a.push(Protocol::Dns(h.into())) - } - Some(url::Host::Ipv4(ip)) => { - a.push(Protocol::Ip4(ip)) - } - Some(url::Host::Ipv6(ip)) => { - a.push(Protocol::Ip6(ip)) - } - None => return Err(Error::InvalidRedirectLocation) + Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())), + Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)), + Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)), + None => return Err(Error::InvalidRedirectLocation), } if let Some(p) = url.port() { a.push(Protocol::Tcp(p)) @@ -429,7 +459,7 @@ fn location_to_multiaddr(location: &str) -> Result> { a.push(Protocol::Ws(url.path().into())) } else { debug!("unsupported scheme: {}", s); - return Err(Error::InvalidRedirectLocation) + return Err(Error::InvalidRedirectLocation); } Ok(a) } @@ -444,7 +474,7 @@ fn location_to_multiaddr(location: &str) -> Result> { pub struct Connection { receiver: BoxStream<'static, Result>, sender: Pin + Send>>, - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } /// Data received over the websocket connection. @@ -455,7 +485,7 @@ pub enum IncomingData { /// UTF-8 encoded application data. Text(Vec), /// PONG control frame data. - Pong(Vec) + Pong(Vec), } impl IncomingData { @@ -464,22 +494,34 @@ impl IncomingData { } pub fn is_binary(&self) -> bool { - if let IncomingData::Binary(_) = self { true } else { false } + if let IncomingData::Binary(_) = self { + true + } else { + false + } } pub fn is_text(&self) -> bool { - if let IncomingData::Text(_) = self { true } else { false } + if let IncomingData::Text(_) = self { + true + } else { + false + } } pub fn is_pong(&self) -> bool { - if let IncomingData::Pong(_) = self { true } else { false } + if let IncomingData::Pong(_) = self { + true + } else { + false + } } pub fn into_bytes(self) -> Vec { match self { IncomingData::Binary(d) => d, IncomingData::Text(d) => d, - IncomingData::Pong(d) => d + IncomingData::Pong(d) => d, } } } @@ -489,7 +531,7 @@ impl AsRef<[u8]> for IncomingData { match self { IncomingData::Binary(d) => d, IncomingData::Text(d) => d, - IncomingData::Pong(d) => d + IncomingData::Pong(d) => d, } } } @@ -503,7 +545,7 @@ pub enum OutgoingData { Ping(Vec), /// Send an unsolicited PONG message. /// (Incoming PINGs are answered automatically.) - Pong(Vec) + Pong(Vec), } impl fmt::Debug for Connection { @@ -514,7 +556,7 @@ impl fmt::Debug for Connection { impl Connection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { fn new(builder: connection::Builder>) -> Self { let (sender, receiver) = builder.finish(); @@ -536,29 +578,31 @@ where sender.send_pong(data).await? } quicksink::Action::Flush => sender.flush().await?, - quicksink::Action::Close => sender.close().await? + quicksink::Action::Close => sender.close().await?, } Ok(sender) }); let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async { match receiver.receive(&mut data).await { - Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => { - Some((Ok(IncomingData::Text(mem::take(&mut data))), (data, receiver))) - } - Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => { - Some((Ok(IncomingData::Binary(mem::take(&mut data))), (data, receiver))) - } + Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some(( + Ok(IncomingData::Text(mem::take(&mut data))), + (data, receiver), + )), + Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some(( + Ok(IncomingData::Binary(mem::take(&mut data))), + (data, receiver), + )), Ok(soketto::Incoming::Pong(pong)) => { Some((Ok(IncomingData::Pong(Vec::from(pong))), (data, receiver))) } Err(connection::Error::Closed) => None, - Err(e) => Some((Err(e), (data, receiver))) + Err(e) => Some((Err(e), (data, receiver))), } }); Connection { receiver: stream.boxed(), sender: Box::pin(sink), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } @@ -580,22 +624,20 @@ where impl Stream for Connection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Item = io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let item = ready!(self.receiver.poll_next_unpin(cx)); - let item = item.map(|result| { - result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - }); + let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e))); Poll::Ready(item) } } impl Sink for Connection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Error = io::Error; diff --git a/transports/websocket/src/lib.rs b/transports/websocket/src/lib.rs index 4473ed65d73..387aee7c72c 100644 --- a/transports/websocket/src/lib.rs +++ b/transports/websocket/src/lib.rs @@ -26,20 +26,26 @@ pub mod tls; use error::Error; use framed::Connection; -use futures::{future::BoxFuture, prelude::*, stream::BoxStream, ready}; +use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream}; use libp2p_core::{ - ConnectedPoint, - Transport, multiaddr::Multiaddr, - transport::{map::{MapFuture, MapStream}, ListenerEvent, TransportError} + transport::{ + map::{MapFuture, MapStream}, + ListenerEvent, TransportError, + }, + ConnectedPoint, Transport, }; use rw_stream_sink::RwStreamSink; -use std::{io, pin::Pin, task::{Context, Poll}}; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; /// A Websocket transport. #[derive(Debug, Clone)] pub struct WsConfig { - transport: framed::WsConfig + transport: framed::WsConfig, } impl WsConfig { @@ -92,9 +98,7 @@ impl WsConfig { impl From> for WsConfig { fn from(framed: framed::WsConfig) -> Self { - WsConfig { - transport: framed - } + WsConfig { transport: framed } } } @@ -105,7 +109,7 @@ where T::Dial: Send + 'static, T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, - T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static + T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = RwStreamSink>; type Error = Error; @@ -114,11 +118,15 @@ where type Dial = MapFuture, WrapperFn>; fn listen_on(self, addr: Multiaddr) -> Result> { - self.transport.map(wrap_connection as WrapperFn).listen_on(addr) + self.transport + .map(wrap_connection as WrapperFn) + .listen_on(addr) } fn dial(self, addr: Multiaddr) -> Result> { - self.transport.map(wrap_connection as WrapperFn).dial(addr) + self.transport + .map(wrap_connection as WrapperFn) + .dial(addr) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { @@ -127,7 +135,8 @@ where } /// Type alias corresponding to `framed::WsConfig::Listener`. -pub type InnerStream = BoxStream<'static, Result, Error>, Error>>; +pub type InnerStream = + BoxStream<'static, Result, Error>, Error>>; /// Type alias corresponding to `framed::WsConfig::Dial` and `framed::WsConfig::ListenerUpgrade`. pub type InnerFuture = BoxFuture<'static, Result, Error>>; @@ -139,7 +148,7 @@ pub type WrapperFn = fn(Connection, ConnectedPoint) -> RwStreamSink(c: Connection, _: ConnectedPoint) -> RwStreamSink> where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { RwStreamSink::new(BytesConnection(c)) } @@ -150,7 +159,7 @@ pub struct BytesConnection(Connection); impl Stream for BytesConnection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Item = io::Result>; @@ -158,10 +167,10 @@ where loop { if let Some(item) = ready!(self.0.try_poll_next_unpin(cx)?) { if item.is_data() { - return Poll::Ready(Some(Ok(item.into_bytes()))) + return Poll::Ready(Some(Ok(item.into_bytes()))); } } else { - return Poll::Ready(None) + return Poll::Ready(None); } } } @@ -169,7 +178,7 @@ where impl Sink> for BytesConnection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Error = io::Error; @@ -194,10 +203,10 @@ where #[cfg(test)] mod tests { - use libp2p_core::{Multiaddr, PeerId, Transport, multiaddr::Protocol}; - use libp2p_tcp as tcp; - use futures::prelude::*; use super::WsConfig; + use futures::prelude::*; + use libp2p_core::{multiaddr::Protocol, Multiaddr, PeerId, Transport}; + use libp2p_tcp as tcp; #[test] fn dialer_connects_to_listener_ipv4() { @@ -214,11 +223,11 @@ mod tests { async fn connect(listen_addr: Multiaddr) { let ws_config = WsConfig::new(tcp::TcpConfig::new()); - let mut listener = ws_config.clone() - .listen_on(listen_addr) - .expect("listener"); + let mut listener = ws_config.clone().listen_on(listen_addr).expect("listener"); - let addr = listener.try_next().await + let addr = listener + .try_next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -228,7 +237,8 @@ mod tests { assert_ne!(Some(Protocol::Tcp(0)), addr.iter().nth(1)); let inbound = async move { - let (conn, _addr) = listener.try_filter_map(|e| future::ready(Ok(e.into_upgrade()))) + let (conn, _addr) = listener + .try_filter_map(|e| future::ready(Ok(e.into_upgrade()))) .try_next() .await .unwrap() @@ -236,7 +246,9 @@ mod tests { conn.await }; - let outbound = ws_config.dial(addr.with(Protocol::P2p(PeerId::random().into()))).unwrap(); + let outbound = ws_config + .dial(addr.with(Protocol::P2p(PeerId::random().into()))) + .unwrap(); let (a, b) = futures::join!(inbound, outbound); a.and(b).unwrap(); diff --git a/transports/websocket/src/tls.rs b/transports/websocket/src/tls.rs index d72535cdcc3..5aab39fe59b 100644 --- a/transports/websocket/src/tls.rs +++ b/transports/websocket/src/tls.rs @@ -18,14 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures_rustls::{rustls, webpki, TlsConnector, TlsAcceptor}; +use futures_rustls::{rustls, webpki, TlsAcceptor, TlsConnector}; use std::{fmt, io, sync::Arc}; /// TLS configuration. #[derive(Clone)] pub struct Config { pub(crate) client: TlsConnector, - pub(crate) server: Option + pub(crate) server: Option, } impl fmt::Debug for Config { @@ -60,7 +60,7 @@ impl Config { /// Create a new TLS configuration with the given server key and certificate chain. pub fn new(key: PrivateKey, certs: I) -> Result where - I: IntoIterator + I: IntoIterator, { let mut builder = Config::builder(); builder.server(key, certs)?; @@ -71,45 +71,55 @@ impl Config { pub fn client() -> Self { Config { client: Arc::new(client_config()).into(), - server: None + server: None, } } /// Create a new TLS configuration builder. pub fn builder() -> Builder { - Builder { client: client_config(), server: None } + Builder { + client: client_config(), + server: None, + } } } /// Setup the rustls client configuration. fn client_config() -> rustls::ClientConfig { let mut client = rustls::ClientConfig::new(); - client.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + client + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); client } /// TLS configuration builder. pub struct Builder { client: rustls::ClientConfig, - server: Option + server: Option, } impl Builder { /// Set server key and certificate chain. pub fn server(&mut self, key: PrivateKey, certs: I) -> Result<&mut Self, Error> where - I: IntoIterator + I: IntoIterator, { let mut server = rustls::ServerConfig::new(rustls::NoClientAuth::new()); let certs = certs.into_iter().map(|c| c.0).collect(); - server.set_single_cert(certs, key.0).map_err(|e| Error::Tls(Box::new(e)))?; + server + .set_single_cert(certs, key.0) + .map_err(|e| Error::Tls(Box::new(e)))?; self.server = Some(server); Ok(self) } /// Add an additional trust anchor. pub fn add_trust(&mut self, cert: &Certificate) -> Result<&mut Self, Error> { - self.client.root_store.add(&cert.0).map_err(|e| Error::Tls(Box::new(e)))?; + self.client + .root_store + .add(&cert.0) + .map_err(|e| Error::Tls(Box::new(e)))?; Ok(self) } @@ -117,7 +127,7 @@ impl Builder { pub fn finish(self) -> Config { Config { client: Arc::new(self.client).into(), - server: self.server.map(|s| Arc::new(s).into()) + server: self.server.map(|s| Arc::new(s).into()), } } } @@ -155,7 +165,7 @@ impl std::error::Error for Error { match self { Error::Io(e) => Some(e), Error::Tls(e) => Some(&**e), - Error::InvalidDnsName(_) => None + Error::InvalidDnsName(_) => None, } } }