diff --git a/openraft/src/core/mod.rs b/openraft/src/core/mod.rs index f95aacfb0..ce9ffab2c 100644 --- a/openraft/src/core/mod.rs +++ b/openraft/src/core/mod.rs @@ -195,7 +195,7 @@ pub struct RaftCore, S: RaftStorage< tx_compaction: mpsc::Sender>, rx_compaction: mpsc::Receiver>, - rx_api: mpsc::UnboundedReceiver<(RaftMsg, Span)>, + rx_api: mpsc::UnboundedReceiver<(RaftMsg, Span)>, tx_metrics: watch::Sender>, @@ -208,7 +208,7 @@ impl, S: RaftStorage> RaftCore, network: N, storage: S, - rx_api: mpsc::UnboundedReceiver<(RaftMsg, Span)>, + rx_api: mpsc::UnboundedReceiver<(RaftMsg, Span)>, tx_metrics: watch::Sender>, rx_shutdown: oneshot::Receiver<()>, ) -> JoinHandle>> { @@ -849,7 +849,7 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory, S: RaftStorage> LeaderS } #[tracing::instrument(level = "debug", skip(self, msg), fields(state = "leader", id=display(self.core.id)))] - pub async fn handle_msg(&mut self, msg: RaftMsg) -> Result<(), Fatal> { + pub async fn handle_msg(&mut self, msg: RaftMsg) -> Result<(), Fatal> { tracing::debug!("recv from rx_api: {}", msg.summary()); match msg { @@ -885,6 +885,9 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory, S: RaftStorage> LeaderS } => { self.change_membership(members, blocking, turn_to_learner, tx).await?; } + RaftMsg::ExternalRequest { req } => { + req(State::Leader, &mut self.core.storage, &mut self.core.network); + } }; Ok(()) @@ -1016,7 +1019,7 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory, S: RaftStorage> Candida } #[tracing::instrument(level = "debug", skip(self, msg), fields(state = "candidate", id=display(self.core.id)))] - pub async fn handle_msg(&mut self, msg: RaftMsg) -> Result<(), Fatal> { + pub async fn handle_msg(&mut self, msg: RaftMsg) -> Result<(), Fatal> { tracing::debug!("recv from rx_api: {}", msg.summary()); match msg { RaftMsg::AppendEntries { rpc, tx } => { @@ -1043,6 +1046,9 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory, S: RaftStorage> Candida RaftMsg::ChangeMembership { tx, .. } => { self.core.reject_with_forward_to_leader(tx); } + RaftMsg::ExternalRequest { req } => { + req(State::Candidate, &mut self.core.storage, &mut self.core.network); + } }; Ok(()) } @@ -1091,7 +1097,7 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory, S: RaftStorage> Followe } #[tracing::instrument(level = "debug", skip(self, msg), fields(state = "follower", id=display(self.core.id)))] - pub(crate) async fn handle_msg(&mut self, msg: RaftMsg) -> Result<(), Fatal> { + pub(crate) async fn handle_msg(&mut self, msg: RaftMsg) -> Result<(), Fatal> { tracing::debug!("recv from rx_api: {}", msg.summary()); match msg { @@ -1119,6 +1125,9 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory, S: RaftStorage> Followe RaftMsg::ChangeMembership { tx, .. } => { self.core.reject_with_forward_to_leader(tx); } + RaftMsg::ExternalRequest { req } => { + req(State::Follower, &mut self.core.storage, &mut self.core.network); + } }; Ok(()) } @@ -1165,7 +1174,7 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory, S: RaftStorage> Learner // TODO(xp): define a handle_msg method in RaftCore that decides what to do by current State. #[tracing::instrument(level = "debug", skip(self, msg), fields(state = "learner", id=display(self.core.id)))] - pub(crate) async fn handle_msg(&mut self, msg: RaftMsg) -> Result<(), Fatal> { + pub(crate) async fn handle_msg(&mut self, msg: RaftMsg) -> Result<(), Fatal> { tracing::debug!("recv from rx_api: {}", msg.summary()); match msg { @@ -1193,6 +1202,9 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory, S: RaftStorage> Learner RaftMsg::ChangeMembership { tx, .. } => { self.core.reject_with_forward_to_leader(tx); } + RaftMsg::ExternalRequest { req } => { + req(State::Learner, &mut self.core.storage, &mut self.core.network); + } }; Ok(()) } diff --git a/openraft/src/raft.rs b/openraft/src/raft.rs index c91bae6a8..945565f50 100644 --- a/openraft/src/raft.rs +++ b/openraft/src/raft.rs @@ -38,6 +38,7 @@ use crate::NodeId; use crate::RaftNetworkFactory; use crate::RaftStorage; use crate::SnapshotMeta; +use crate::State; use crate::Vote; /// Configuration of types used by the [`Raft`] core engine. @@ -104,7 +105,7 @@ macro_rules! declare_raft_types { } struct RaftInner, S: RaftStorage> { - tx_api: mpsc::UnboundedSender<(RaftMsg, Span)>, + tx_api: mpsc::UnboundedSender<(RaftMsg, Span)>, rx_metrics: watch::Receiver>, #[allow(clippy::type_complexity)] raft_handle: Mutex>>>>, @@ -407,7 +408,7 @@ impl, S: RaftStorage> Raft(&self, mes: RaftMsg, rx: RaftRespRx) -> Result + pub(crate) async fn call_core(&self, mes: RaftMsg, rx: RaftRespRx) -> Result where E: From> { let span = tracing::Span::current(); @@ -451,6 +452,25 @@ impl, S: RaftStorage> Raft(&self, req: F) { + let _ignore_error = self.inner.tx_api.send(( + RaftMsg::ExternalRequest { req: Box::new(req) }, + tracing::span::Span::none(), // fire-and-forget, so no span + )); + } + /// Get a handle to the metrics channel. pub fn metrics(&self) -> watch::Receiver> { self.inner.rx_metrics.clone() @@ -513,7 +533,7 @@ pub struct AddLearnerResponse { } /// A message coming from the Raft API. -pub(crate) enum RaftMsg { +pub(crate) enum RaftMsg, S: RaftStorage> { AppendEntries { rpc: AppendEntriesRequest, tx: RaftRespTx, AppendEntriesError>, @@ -564,10 +584,16 @@ pub(crate) enum RaftMsg { tx: RaftRespTx, ClientWriteError>, }, + ExternalRequest { + req: Box, + }, } -impl MessageSummary for RaftMsg -where C: RaftTypeConfig +impl MessageSummary for RaftMsg +where + C: RaftTypeConfig, + N: RaftNetworkFactory, + S: RaftStorage, { fn summary(&self) -> String { match self { @@ -601,6 +627,7 @@ where C: RaftTypeConfig members, blocking, turn_to_learner, ) } + RaftMsg::ExternalRequest { .. } => "External Request".to_string(), } } } diff --git a/openraft/tests/fixtures/mod.rs b/openraft/tests/fixtures/mod.rs index 657ff94aa..5faa73191 100644 --- a/openraft/tests/fixtures/mod.rs +++ b/openraft/tests/fixtures/mod.rs @@ -577,6 +577,19 @@ where } } + /// Send external request to the particular node. + pub fn external_request, &mut TypedRaftRouter) + Send + 'static>( + &self, + target: C::NodeId, + req: F, + ) { + let rt = self.routing_table.lock().unwrap(); + rt.get(&target) + .unwrap_or_else(|| panic!("node '{}' does not exist in routing table", target)) + .0 + .external_request(req) + } + /// Request the current leader from the target node. pub async fn current_leader(&self, target: C::NodeId) -> Option { let node = self.get_raft_handle(&target).unwrap(); diff --git a/openraft/tests/initialization.rs b/openraft/tests/initialization.rs index c44bbba94..12f56c3a3 100644 --- a/openraft/tests/initialization.rs +++ b/openraft/tests/initialization.rs @@ -14,6 +14,7 @@ use openraft::Membership; use openraft::RaftLogReader; use openraft::RaftStorage; use openraft::State; +use tokio::sync::oneshot; #[macro_use] mod fixtures; @@ -47,6 +48,23 @@ async fn initialization() -> Result<()> { router.wait_for_state(&btreeset![0, 1, 2], State::Learner, timeout(), "empty").await?; router.assert_pristine_cluster().await; + // Sending an external requests will also find all nodes in Learner state. + // + // This demonstrates fire-and-forget external request, which will be serialized + // with other processing. It is not required for the correctness of the test + // + // Since the execution of API messages is serialized, even if the request executes + // some unknown time in the future (due to fire-and-forget semantics), it will + // properly receive the state before initialization, as that state will appear + // later in the sequence. + // + // Also, this external request will be definitely executed, since it's ordered + // before other requests in the Raft core API queue, which definitely are executed + // (since they are awaited). + for node in [0, 1, 2] { + router.external_request(node, |s, _sm, _net| assert_eq!(s, State::Learner)); + } + // Initialize the cluster, then assert that a stable cluster was formed & held. tracing::info!("--- initializing cluster"); router.initialize_from_single_node(0).await?; @@ -78,6 +96,31 @@ async fn initialization() -> Result<()> { ); } + // At this time, one of the nodes is the leader, all the others are followers. + // Check via an external request as well. Again, this is not required for the + // correctness of the test. + // + // This demonstrates how to synchronize on the execution of the external + // request by using a oneshot channel. + let mut found_leader = false; + let mut follower_count = 0; + for node in [0, 1, 2] { + let (tx, rx) = oneshot::channel(); + router.external_request(node, |s, _sm, _net| tx.send(s).unwrap()); + match rx.await.unwrap() { + State::Leader => { + assert!(!found_leader); + found_leader = true; + } + State::Follower => { + follower_count += 1; + } + s => panic!("Unexpected node {} state: {:?}", node, s), + } + } + assert!(found_leader); + assert_eq!(2, follower_count); + Ok(()) }