RUST-1443 Ensure monitors close after server is removed from topology #744

Merged
42 changes: 29 additions & 13 deletions src/client/options/mod.rs
@@ -1325,6 +1325,19 @@ impl ClientOptions {
}
}

if let Some(heartbeat_frequency) = self.heartbeat_freq {
Author comment:
This validation existed, but it was in the URI parsing stage. I've moved it to ClientOptions::validate to ensure it gets hit even when the options are constructed directly.

if heartbeat_frequency < self.min_heartbeat_frequency() {
return Err(ErrorKind::InvalidArgument {
message: format!(
"'heartbeat_freq' must be at least {}ms, but {}ms was given",
self.min_heartbeat_frequency().as_millis(),
heartbeat_frequency.as_millis()
),
}
.into());
}
}

Ok(())
}

@@ -1373,6 +1386,21 @@ impl ClientOptions {
pub(crate) fn test_options_mut(&mut self) -> &mut TestOptions {
self.test_options.get_or_insert_with(Default::default)
}

pub(crate) fn min_heartbeat_frequency(&self) -> Duration {
#[cfg(test)]
{
self.test_options
.as_ref()
.and_then(|to| to.min_heartbeat_freq)
.unwrap_or(MIN_HEARTBEAT_FREQUENCY)
}

#[cfg(not(test))]
{
MIN_HEARTBEAT_FREQUENCY
}
}
}

/// Splits a string into a section before a given index and a section exclusively after the index.
@@ -1865,19 +1893,7 @@ impl ConnectionString {
self.direct_connection = Some(get_bool!(value, k));
}
k @ "heartbeatfrequencyms" => {
let duration = get_duration!(value, k);

if duration < MIN_HEARTBEAT_FREQUENCY.as_millis() as u64 {
return Err(ErrorKind::InvalidArgument {
message: format!(
"'heartbeatFrequencyMS' must be at least 500, but {} was given",
duration
),
}
.into());
}

self.heartbeat_frequency = Some(Duration::from_millis(duration));
self.heartbeat_frequency = Some(Duration::from_millis(get_duration!(value, k)));
}
k @ "journal" => {
let mut write_concern = self.write_concern.get_or_insert_with(Default::default);
14 changes: 14 additions & 0 deletions src/client/options/test.rs
@@ -1,3 +1,5 @@
use std::time::Duration;

use pretty_assertions::assert_eq;
use serde::Deserialize;

@@ -7,6 +9,7 @@ use crate::{
error::ErrorKind,
options::Compressor,
test::run_spec_test,
Client,
};
#[derive(Debug, Deserialize)]
struct TestFile {
@@ -282,3 +285,14 @@ async fn options_debug_omits_uri() {
assert!(!debug_output.contains("password"));
assert!(!debug_output.contains("uri"));
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn options_enforce_min_heartbeat_frequency() {
let options = ClientOptions::builder()
.hosts(vec![ServerAddress::parse("a:123").unwrap()])
.heartbeat_freq(Duration::from_millis(10))
.build();

Client::with_options(options).unwrap_err();
}
48 changes: 20 additions & 28 deletions src/sdam/monitor.rs
@@ -101,25 +101,11 @@ impl Monitor {
// We only go to sleep when using the polling protocol (i.e. server never returned a
// topologyVersion) or when the most recent check failed.
if self.topology_version.is_none() || !check_succeeded {
#[cfg(test)]
let min_frequency = self
.client_options
.test_options
.as_ref()
.and_then(|to| to.min_heartbeat_freq)
.unwrap_or(MIN_HEARTBEAT_FREQUENCY);

#[cfg(not(test))]
let min_frequency = MIN_HEARTBEAT_FREQUENCY;

tokio::select! {
Author comment:
I moved the minimum delay into wait_for_check_request to consolidate the checks for the server closing.
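For reference, a minimal standalone sketch (plain tokio, assumed names, not driver code) of the timing relationship the new signature relies on: the mandatory delay is awaited inside the timed future, so it counts toward the overall timeout rather than adding to it.

use std::time::Duration;
use tokio::sync::watch;
use tokio::time::{sleep, timeout};

// Wait at least `delay` before honoring check requests, but never longer
// than `total` overall: the delay runs inside the `timeout`, so the total
// wait is bounded by `total`, not `delay + total`.
async fn wait_with_floor(delay: Duration, total: Duration, mut requests: watch::Receiver<u32>) {
    let _ = timeout(total, async {
        sleep(delay).await;
        // Only after the minimum delay do we start listening for requests.
        let _ = requests.changed().await;
    })
    .await;
}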

_ = runtime::delay_for(min_frequency) => {},
_ = self.request_receiver.wait_for_server_close() => {
break;
}
}
self.request_receiver
.wait_for_check_request(heartbeat_frequency - min_frequency)
.wait_for_check_request(
self.client_options.min_heartbeat_frequency(),
heartbeat_frequency,
)
.await;
}
}
@@ -607,17 +593,28 @@ impl MonitorRequestReceiver {
err
}

/// Wait for a request to immediately check the server to come in, guarded by the provided
/// timeout. If a cancellation request is received indicating the topology has closed, this
/// method will return. All other cancellation requests will be ignored.
async fn wait_for_check_request(&mut self, timeout: Duration) {
/// Wait for a request to immediately check the server to be received, guarded by the provided
/// timeout. If the server associated with this monitor is removed from the topology, this
/// method will return.
///
/// The `delay` parameter indicates how long this method should wait before listening to
/// requests. The time spent in the delay counts toward the provided timeout.
async fn wait_for_check_request(&mut self, delay: Duration, timeout: Duration) {
let _ = runtime::timeout(timeout, async {
let wait_for_check_request = async {
runtime::delay_for(delay).await;
self.topology_check_request_receiver
.wait_for_check_request()
.await;
};
tokio::pin!(wait_for_check_request);
Author comment:
We need the pin here since we're polling the same future repeatedly in a loop.
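A minimal standalone example of the pattern (assumed names, not part of this diff): select! polls `&mut future`, which only implements Future when the future is pinned (or Unpin), so tokio::pin! pins it to the stack before the loop.

use std::time::Duration;
use tokio::sync::watch;

async fn example(mut shutdown: watch::Receiver<bool>) {
    // A future we want to keep polling across multiple `select!` iterations.
    let work = async {
        tokio::time::sleep(Duration::from_millis(500)).await;
    };
    // Pin it to the stack; `&mut work` now implements `Future`.
    tokio::pin!(work);

    loop {
        tokio::select! {
            _ = &mut work => break,
            changed = shutdown.changed() => {
                // Stop if the sender hung up or a shutdown was signalled.
                if changed.is_err() || *shutdown.borrow() {
                    break;
                }
            }
        }
    }
}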


loop {
tokio::select! {
_ = self.individual_check_request_receiver.changed() => {
break;
}
_ = self.topology_check_request_receiver.wait_for_check_request() => {
_ = &mut wait_for_check_request => {
break;
}
_ = self.handle_listener.wait_for_all_handle_drops() => {
Expand All @@ -633,11 +630,6 @@ impl MonitorRequestReceiver {
self.cancellation_receiver.borrow_and_update();
}

/// Wait until the server associated with this monitor has been closed.
async fn wait_for_server_close(&mut self) {
self.handle_listener.wait_for_all_handle_drops().await;
}

fn is_alive(&self) -> bool {
self.handle_listener.is_alive()
}
93 changes: 93 additions & 0 deletions src/sdam/test.rs
@@ -1,4 +1,5 @@
use std::{
collections::HashSet,
sync::Arc,
time::{Duration, Instant},
};
@@ -8,8 +9,12 @@ use semver::VersionReq;
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};

use crate::{
client::options::{ClientOptions, ServerAddress},
cmap::RawCommandResponse,
error::{Error, ErrorKind},
event::sdam::SdamEventHandler,
hello::{LEGACY_HELLO_COMMAND_NAME, LEGACY_HELLO_COMMAND_NAME_LOWERCASE},
sdam::{ServerDescription, Topology},
test::{
log_uncaptured,
CmapEvent,
@@ -269,3 +274,91 @@ async fn repl_set_name_mismatch() -> crate::error::Result<()> {

Ok(())
}

/// Test verifying that a server's monitor stops after the server has been removed from the
/// topology.
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn removed_server_monitor_stops() -> crate::error::Result<()> {
Author comment:
This test was copied pretty much as-is from #733.

let _guard = LOCK.run_concurrently().await;

let handler = Arc::new(EventHandler::new());
let options = ClientOptions::builder()
.hosts(vec![
ServerAddress::parse("localhost:49152")?,
ServerAddress::parse("localhost:49153")?,
ServerAddress::parse("localhost:49154")?,
])
.heartbeat_freq(Duration::from_millis(50))
.sdam_event_handler(handler.clone() as Arc<dyn SdamEventHandler>)
.repl_set_name("foo".to_string())
.build();

let hosts = options.hosts.clone();
let set_name = options.repl_set_name.clone().unwrap();

let mut subscriber = handler.subscribe();
let topology = Topology::new(options)?;

// Wait until all three monitors have started.
let mut seen_monitors = HashSet::new();
subscriber
.wait_for_event(Duration::from_millis(500), |event| {
if let Event::Sdam(SdamEvent::ServerHeartbeatStarted(e)) = event {
seen_monitors.insert(e.server_address.clone());
}
seen_monitors.len() == hosts.len()
})
.await
.expect("should see all three monitors start");

// Remove the third host from the topology.
let hello = doc! {
"ok": 1,
"isWritablePrimary": true,
"hosts": [
hosts[0].clone().to_string(),
hosts[1].clone().to_string(),
],
"me": hosts[0].clone().to_string(),
"setName": set_name,
"maxBsonObjectSize": 1234,
"maxWriteBatchSize": 1234,
"maxMessageSizeBytes": 1234,
"minWireVersion": 0,
"maxWireVersion": 13,
};
let hello_reply = RawCommandResponse::with_document_and_address(hosts[0].clone(), hello)
.unwrap()
.into_hello_reply()
.unwrap();

topology
.clone_updater()
.update(ServerDescription::new_from_hello_reply(
hosts[0].clone(),
hello_reply,
Duration::from_millis(10),
))
.await;

subscriber.wait_for_event(Duration::from_secs(1), |event| {
matches!(event, Event::Sdam(SdamEvent::ServerClosed(e)) if e.address == hosts[2])
}).await.expect("should see server closed event");

// Capture heartbeat events for 1 second. The monitor for the removed server should stop
// publishing them.
let events = subscriber.collect_events(Duration::from_secs(1), |event| {
matches!(event, Event::Sdam(SdamEvent::ServerHeartbeatStarted(e)) if e.server_address == hosts[2])
}).await;

// Use 3 to account for any heartbeats that happen to start between emitting the ServerClosed
// event and actually publishing the state with the closed server.
assert!(
events.len() < 3,
"expected monitor for removed server to stop performing checks, but saw {} heartbeats",
events.len()
);

Ok(())
}
9 changes: 5 additions & 4 deletions src/sdam/topology.rs
@@ -1081,10 +1081,11 @@ pub(crate) struct TopologyCheckRequestReceiver {

impl TopologyCheckRequestReceiver {
Author comment:
This was a minor change I made to mitigate the case where the value hasn't changed but there are still operations requesting an update. I don't believe this case could actually happen with the existing implementation, but I think it's best to make that clear here.
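A small standalone illustration (plain tokio watch channel, not driver code) of the case being guarded against: once a value has been marked as seen (for example via borrow_and_update elsewhere in the receiver), changed() blocks until the next send, so awaiting it first could miss a count that is already non-zero; checking borrow() before awaiting does not.

use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(0u32);

    // A check request arrives and is marked as seen before the wait begins,
    // e.g. via `borrow_and_update` in another code path.
    tx.send(1).unwrap();
    let _ = rx.borrow_and_update();

    // The value is already marked seen, so `rx.changed().await` would now
    // block until the *next* send. Checking the current value first still
    // notices the pending request immediately.
    while *rx.borrow() == 0 {
        if rx.changed().await.is_err() {
            return; // all requesters hung up
        }
    }
    assert_eq!(*rx.borrow(), 1);
}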

pub(crate) async fn wait_for_check_request(&mut self) {
while self.receiver.changed().await.is_ok() {
if *self.receiver.borrow() > 0 {
break;
}
while *self.receiver.borrow() == 0 {
// If all the requesters hung up, then just return early.
if self.receiver.changed().await.is_err() {
return;
};
}
}
}