Skip to content

Commit

Permalink
[ISSUE #1056]Optimize CountDownLatch⚡️ (#1057)
Browse files Browse the repository at this point in the history
  • Loading branch information
mxsm authored Oct 14, 2024
1 parent fd2f251 commit c201729
Showing 1 changed file with 43 additions and 18 deletions.
61 changes: 43 additions & 18 deletions rocketmq/src/count_down_latch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
* limitations under the License.
*/
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Mutex;
use tokio::sync::Notify;

/// A synchronization aid that allows one or more tasks to wait until a set of operations being
/// performed in other tasks completes.
#[derive(Clone)]
pub struct CountDownLatch {
/// The current count of the latch.
count: Mutex<u32>,
count: Arc<Mutex<u32>>,
/// A notification mechanism to wake up waiting tasks.
notify: Notify,
notify: Arc<Notify>,
}

impl CountDownLatch {
Expand All @@ -38,22 +40,18 @@ impl CountDownLatch {
///
/// # Returns
///
/// An `Arc` to the newly created `CountDownLatch`.
pub fn new(count: u32) -> Arc<Self> {
Arc::new(CountDownLatch {
count: Mutex::new(count),
notify: Notify::new(),
})
/// A new `CountDownLatch`.
pub fn new(count: u32) -> Self {
CountDownLatch {
count: Arc::new(Mutex::new(count)),
notify: Arc::new(Notify::new()),
}
}

/// Decrements the count of the latch, releasing all waiting tasks if the count reaches zero.
///
/// This method is asynchronous and will lock the internal count before decrementing it.
///
/// # Arguments
///
/// * `self` - An `Arc` to the `CountDownLatch`.
pub async fn count_down(self: Arc<Self>) {
pub async fn count_down(&self) {
let mut count = self.count.lock().await;
*count -= 1;
if *count == 0 {
Expand All @@ -64,17 +62,29 @@ impl CountDownLatch {
/// Waits until the count reaches zero.
///
/// This method is asynchronous and will block the current task until the count reaches zero.
///
/// # Arguments
///
/// * `self` - An `Arc` to the `CountDownLatch`.
pub async fn wait(self: Arc<Self>) {
pub async fn wait(&self) {
let count = self.count.lock().await;
if *count > 0 {
drop(count);
self.notify.notified().await;
}
}

/// Waits until the count reaches zero or the specified timeout elapses.
///
/// This method is asynchronous and will block the current task until the count reaches zero
/// or the timeout elapses.
///
/// # Arguments
///
/// * `timeout` - The maximum duration to wait for the count to reach zero.
///
/// # Returns
///
/// `true` if the count reached zero before the timeout elapsed, `false` otherwise.
pub async fn wait_timeout(&self, timeout: Duration) -> bool {
tokio::time::timeout(timeout, self.wait()).await.is_ok()
}
}

#[cfg(test)]
Expand All @@ -88,6 +98,21 @@ mod tests {
assert_eq!(*count, 3);
}

#[tokio::test]
async fn wait_timeout_reaches_zero_before_timeout() {
let latch = CountDownLatch::new(1);
latch.count_down().await;
let result = latch.wait_timeout(Duration::from_secs(1)).await;
assert!(result);
}

#[tokio::test]
async fn wait_timeout_exceeds_timeout() {
let latch = CountDownLatch::new(1);
let result = latch.wait_timeout(Duration::from_millis(10)).await;
assert!(!result);
}

#[tokio::test]
async fn count_down_latch_count_down() {
let latch = CountDownLatch::new(3);
Expand Down

0 comments on commit c201729

Please sign in to comment.