From 531740550fc6443657740be9902b2ae088fed816 Mon Sep 17 00:00:00 2001
From: Austin Bonander
Date: Wed, 21 Jul 2021 16:28:44 -0700
Subject: [PATCH] fix(pool): reimplement pool internals with `futures-intrusive` (#1320)

---
 Cargo.lock                       |  12 ++
 sqlx-core/Cargo.toml             |   1 +
 sqlx-core/src/pool/connection.rs |  42 ++--
 sqlx-core/src/pool/inner.rs      | 348 +++++++++++++------------------
 sqlx-core/src/pool/mod.rs        |   4 +-
 sqlx-core/src/pool/options.rs    |  12 +-
 tests/postgres/postgres.rs       |  15 +-
 7 files changed, 201 insertions(+), 233 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 4f165142eb..817f5ffdfe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -883,6 +883,17 @@ dependencies = [
  "futures-util",
 ]
 
+[[package]]
+name = "futures-intrusive"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "62007592ac46aa7c2b6416f7deb9a8a8f63a01e0f1d6e1787d5630170db2b63e"
+dependencies = [
+ "futures-core",
+ "lock_api",
+ "parking_lot",
+]
+
 [[package]]
 name = "futures-io"
 version = "0.3.15"
@@ -2356,6 +2367,7 @@ dependencies = [
  "encoding_rs",
  "futures-channel",
  "futures-core",
+ "futures-intrusive",
  "futures-util",
  "generic-array",
  "git2",
diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml
index 9d2ee65d8e..468a77a528 100644
--- a/sqlx-core/Cargo.toml
+++ b/sqlx-core/Cargo.toml
@@ -120,6 +120,7 @@ encoding_rs = { version = "0.8.23", optional = true }
 either = "1.5.3"
 futures-channel = { version = "0.3.5", default-features = false, features = ["sink", "alloc", "std"] }
 futures-core = { version = "0.3.5", default-features = false }
+futures-intrusive = "0.4.0"
 futures-util = { version = "0.3.5", features = ["sink"] }
 generic-array = { version = "0.14.4", default-features = false, optional = true }
 hex = "0.4.2"
diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs
index 732c1a8c92..415fd0c1ec 100644
--- a/sqlx-core/src/pool/connection.rs
+++ b/sqlx-core/src/pool/connection.rs
@@ -1,13 +1,16 @@
-use super::inner::{DecrementSizeGuard, SharedPool};
-use crate::connection::Connection;
-use crate::database::Database;
-use crate::error::Error;
-use sqlx_rt::spawn;
 use std::fmt::{self, Debug, Formatter};
 use std::ops::{Deref, DerefMut};
 use std::sync::Arc;
 use std::time::Instant;
 
+use futures_intrusive::sync::SemaphoreReleaser;
+
+use crate::connection::Connection;
+use crate::database::Database;
+use crate::error::Error;
+
+use super::inner::{DecrementSizeGuard, SharedPool};
+
 /// A connection managed by a [`Pool`][crate::pool::Pool].
 ///
 /// Will be returned to the pool on-drop.
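
The on-drop return path described by this doc comment is load-bearing for the new semaphore accounting: dropping a `PoolConnection` spawns a task (see the `Drop` impl below) that floats the live connection and hands it, along with its permit, back to the pool rather than closing it. A minimal usage sketch follows; the connection URL and the tokio runtime dependency are assumptions for illustration, not part of this patch:

```rust
use sqlx::postgres::PgPoolOptions;

#[tokio::main] // assumes one of sqlx's `runtime-tokio-*` features and tokio as a dependency
async fn main() -> Result<(), sqlx::Error> {
    let pool = PgPoolOptions::new()
        .max_connections(5)
        .connect("postgres://localhost/test") // hypothetical URL
        .await?;

    {
        let mut conn = pool.acquire().await?;
        sqlx::query("SELECT 1").execute(&mut *conn).await?;
    } // `conn` drops here: it is returned to the pool asynchronously, not closed

    Ok(())
}
```
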
@@ -28,8 +31,8 @@ pub(super) struct Idle<DB: Database> {
 
 /// RAII wrapper for connections being handled by functions that may drop them
 pub(super) struct Floating<'p, C> {
-    inner: C,
-    guard: DecrementSizeGuard<'p>,
+    pub(super) inner: C,
+    pub(super) guard: DecrementSizeGuard<'p>,
 }
 
 const DEREF_ERR: &str = "(bug) connection already released to pool";
@@ -71,7 +74,7 @@ impl<DB: Database> Drop for PoolConnection<DB> {
     fn drop(&mut self) {
         if let Some(live) = self.live.take() {
             let pool = self.pool.clone();
-            spawn(async move {
+            sqlx_rt::spawn(async move {
                 let mut floating = live.float(&pool);
 
                 // test the connection on-release to ensure it is still viable
@@ -102,7 +105,8 @@ impl<DB: Database> Live<DB> {
     pub fn float(self, pool: &SharedPool<DB>) -> Floating<'_, Self> {
         Floating {
             inner: self,
-            guard: DecrementSizeGuard::new(pool),
+            // create a new guard from a previously leaked permit
+            guard: DecrementSizeGuard::new_permit(pool),
         }
     }
 
@@ -161,6 +165,11 @@ impl<'s, DB: Database> Floating<'s, Live<DB>> {
         }
     }
 
+    pub async fn close(self) -> Result<(), Error> {
+        // `guard` is dropped as intended
+        self.inner.raw.close().await
+    }
+
     pub fn detach(self) -> DB::Connection {
         self.inner.raw
     }
@@ -174,10 +183,14 @@ impl<'s, DB: Database> Floating<'s, Live<DB>> {
 }
 
 impl<'s, DB: Database> Floating<'s, Idle<DB>> {
-    pub fn from_idle(idle: Idle<DB>, pool: &'s SharedPool<DB>) -> Self {
+    pub fn from_idle(
+        idle: Idle<DB>,
+        pool: &'s SharedPool<DB>,
+        permit: SemaphoreReleaser<'s>,
+    ) -> Self {
         Self {
             inner: idle,
-            guard: DecrementSizeGuard::new(pool),
+            guard: DecrementSizeGuard::from_permit(pool, permit),
         }
     }
 
@@ -192,9 +205,12 @@ impl<'s, DB: Database> Floating<'s, Idle<DB>> {
         }
     }
 
-    pub async fn close(self) -> Result<(), Error> {
+    pub async fn close(self) -> DecrementSizeGuard<'s> {
         // `guard` is dropped as intended
-        self.inner.live.raw.close().await
+        if let Err(e) = self.inner.live.raw.close().await {
+            log::debug!("error occurred while closing the pool connection: {}", e);
+        }
+        self.guard
     }
 }
diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs
index f9e5df43b3..e297537e36 100644
--- a/sqlx-core/src/pool/inner.rs
+++ b/sqlx-core/src/pool/inner.rs
@@ -6,6 +6,7 @@ use crate::error::Error;
 use crate::pool::{deadline_as_timeout, PoolOptions};
 use crossbeam_queue::{ArrayQueue, SegQueue};
 use futures_core::task::{Poll, Waker};
+use futures_intrusive::sync::{Semaphore, SemaphoreReleaser};
 use futures_util::future;
 use std::cmp;
 use std::mem;
@@ -15,12 +16,16 @@ use std::sync::{Arc, Weak};
 use std::task::Context;
 use std::time::{Duration, Instant};
 
-type Waiters = SegQueue<Weak<WaiterInner>>;
+/// The number of permits to release to wake all waiters, such as on `SharedPool::close()`.
+///
+/// This should be large enough to realistically wake all tasks waiting on the pool without
+/// potentially overflowing the permits count in the semaphore itself.
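+///
+/// `usize::MAX / 2` leaves ample headroom: `max_connections` is a `u32`, so on
+/// 64-bit targets `capacity + WAKE_ALL_PERMITS` cannot overflow `usize`; the
+/// `checked_add` assert in `new_arc()` below covers narrower targets.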
+const WAKE_ALL_PERMITS: usize = usize::MAX / 2;
 
 pub(crate) struct SharedPool<DB: Database> {
     pub(super) connect_options: <DB::Connection as Connection>::Options,
     pub(super) idle_conns: ArrayQueue<Idle<DB>>,
-    waiters: Waiters,
+    pub(super) semaphore: Semaphore,
     pub(super) size: AtomicU32,
     is_closed: AtomicBool,
     pub(super) options: PoolOptions<DB>,
@@ -31,10 +36,18 @@ impl<DB: Database> SharedPool<DB> {
         options: PoolOptions<DB>,
         connect_options: <DB::Connection as Connection>::Options,
     ) -> Arc<Self> {
+        let capacity = options.max_connections as usize;
+
+        // ensure the permit count won't overflow if we release `WAKE_ALL_PERMITS`
+        // this assert should never fire on 64-bit targets as `max_connections` is a u32
+        let _ = capacity
+            .checked_add(WAKE_ALL_PERMITS)
+            .expect("max_connections exceeds max capacity of the pool");
+
         let pool = Self {
             connect_options,
-            idle_conns: ArrayQueue::new(options.max_connections as usize),
-            waiters: SegQueue::new(),
+            idle_conns: ArrayQueue::new(capacity),
+            semaphore: Semaphore::new(options.fair, capacity),
             size: AtomicU32::new(0),
             is_closed: AtomicBool::new(false),
             options,
@@ -61,148 +74,133 @@ impl<DB: Database> SharedPool<DB> {
     }
 
     pub(super) async fn close(&self) {
-        self.is_closed.store(true, Ordering::Release);
-        while let Some(waker) = self.waiters.pop() {
-            if let Some(waker) = waker.upgrade() {
-                waker.wake();
-            }
+        let already_closed = self.is_closed.swap(true, Ordering::AcqRel);
+
+        if !already_closed {
+            // if we were the one to mark this closed, release enough permits to wake all waiters
+            // we can't just do `usize::MAX` because that would overflow
+            // and we can't do this more than once cause that would _also_ overflow
+            self.semaphore.release(WAKE_ALL_PERMITS);
         }
 
-        // ensure we wait until the pool is actually closed
-        while self.size() > 0 {
-            if let Some(idle) = self.idle_conns.pop() {
-                if let Err(e) = Floating::from_idle(idle, self).close().await {
-                    log::warn!("error occurred while closing the pool connection: {}", e);
-                }
-            }
+        // wait for all permits to be released
+        let _permits = self
+            .semaphore
+            .acquire(WAKE_ALL_PERMITS + (self.options.max_connections as usize))
+            .await;
 
-            // yield to avoid starving the executor
-            sqlx_rt::yield_now().await;
+        while let Some(idle) = self.idle_conns.pop() {
+            idle.live.float(self).close().await;
         }
     }
 
     #[inline]
-    pub(super) fn try_acquire(&self) -> Option<Floating<'_, Live<DB>>> {
-        // don't cut in line
-        if self.options.fair && !self.waiters.is_empty() {
+    pub(super) fn try_acquire(&self) -> Option<Floating<'_, Idle<DB>>> {
+        if self.is_closed() {
             return None;
         }
-        Some(self.pop_idle()?.into_live())
+
+        let permit = self.semaphore.try_acquire(1)?;
+        self.pop_idle(permit).ok()
     }
 
-    fn pop_idle(&self) -> Option<Floating<'_, Idle<DB>>> {
-        if self.is_closed.load(Ordering::Acquire) {
-            return None;
+    fn pop_idle<'a>(
+        &'a self,
+        permit: SemaphoreReleaser<'a>,
+    ) -> Result<Floating<'a, Idle<DB>>, SemaphoreReleaser<'a>> {
+        if let Some(idle) = self.idle_conns.pop() {
+            Ok(Floating::from_idle(idle, self, permit))
+        } else {
+            Err(permit)
         }
-
-        Some(Floating::from_idle(self.idle_conns.pop()?, self))
     }
 
     pub(super) fn release(&self, mut floating: Floating<'_, Live<DB>>) {
         if let Some(test) = &self.options.after_release {
             if !test(&mut floating.raw) {
-                // drop the connection and do not return to the pool
+                // drop the connection and do not return it to the pool
                 return;
             }
         }
 
-        let is_ok = self
-            .idle_conns
-            .push(floating.into_idle().into_leakable())
-            .is_ok();
+        let Floating { inner: idle, guard } = floating.into_idle();
 
-        if !is_ok {
+        if !self.idle_conns.push(idle).is_ok() {
             panic!("BUG: connection queue overflow in release()");
         }
 
-        wake_one(&self.waiters);
+        // NOTE: we need to make sure we drop the permit *after* we push to the idle queue
+        // don't decrease the size
+        guard.release_permit();
     }
 
     /// Try to atomically increment the pool size for a new connection.
     ///
     /// Returns `None` if we are at max_connections or if the pool is closed.
-    pub(super) fn try_increment_size(&self) -> Option<DecrementSizeGuard<'_>> {
-        if self.is_closed() {
-            return None;
+    pub(super) fn try_increment_size<'a>(
+        &'a self,
+        permit: SemaphoreReleaser<'a>,
+    ) -> Result<DecrementSizeGuard<'a>, SemaphoreReleaser<'a>> {
+        match self
+            .size
+            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| {
+                size.checked_add(1)
+                    .filter(|size| size <= &self.options.max_connections)
+            }) {
+            // we successfully incremented the size
+            Ok(_) => Ok(DecrementSizeGuard::from_permit(self, permit)),
+            // the pool is at max capacity
+            Err(_) => Err(permit),
         }
-
-        let mut size = self.size();
-
-        while size < self.options.max_connections {
-            match self
-                .size
-                .compare_exchange(size, size + 1, Ordering::AcqRel, Ordering::Acquire)
-            {
-                Ok(_) => return Some(DecrementSizeGuard::new(self)),
-                Err(new_size) => size = new_size,
-            }
-        }
-
-        None
     }
 
     #[allow(clippy::needless_lifetimes)]
     pub(super) async fn acquire<'s>(&'s self) -> Result<Floating<'s, Live<DB>>, Error> {
-        let start = Instant::now();
-        let deadline = start + self.options.connect_timeout;
-        let mut waited = !self.options.fair;
-
-        // the strong ref of the `Weak` that we push to the queue
-        // initialized during the `timeout()` call below
-        // as long as we own this, we keep our place in line
-        let mut waiter: Option<Waiter<'s>> = None;
-
-        // Unless the pool has been closed ...
-        while !self.is_closed() {
-            // Don't cut in line unless no one is waiting
-            if waited || self.waiters.is_empty() {
-                // Attempt to immediately acquire a connection. This will return Some
-                // if there is an idle connection in our channel.
-                if let Some(conn) = self.pop_idle() {
-                    if let Some(live) = check_conn(conn, &self.options).await {
-                        return Ok(live);
-                    }
-                }
+        if self.is_closed() {
+            return Err(Error::PoolClosed);
+        }
 
-                // check if we can open a new connection
-                if let Some(guard) = self.try_increment_size() {
-                    // pool has slots available; open a new connection
-                    return self.connection(deadline, guard).await;
-                }
-            }
+        let deadline = Instant::now() + self.options.connect_timeout;
 
-            if let Some(ref waiter) = waiter {
-                // return the waiter to the queue, note that this does put it to the back
-                // of the queue when it should ideally stay at the front
-                self.waiters.push(Arc::downgrade(&waiter.inner));
-            }
+        sqlx_rt::timeout(
+            self.options.connect_timeout,
+            async {
+                loop {
+                    let permit = self.semaphore.acquire(1).await;
 
-            sqlx_rt::timeout(
-                // Returns an error if `deadline` passes
-                deadline_as_timeout::<DB>(deadline)?,
-                // `poll_fn` gets us easy access to a `Waker` that we can push to our queue
-                future::poll_fn(|cx| -> Poll<()> {
-                    let waiter = waiter.get_or_insert_with(|| Waiter::push_new(cx, &self.waiters));
-
-                    if waiter.is_woken() {
-                        waiter.actually_woke = true;
-                        Poll::Ready(())
-                    } else {
-                        Poll::Pending
+                    if self.is_closed() {
+                        return Err(Error::PoolClosed);
                     }
-                }),
-            )
-            .await
-            .map_err(|_| Error::PoolTimedOut)?;
 
-            if let Some(ref mut waiter) = waiter {
-                waiter.reset();
+                    // First attempt to pop a connection from the idle queue.
+                    let guard = match self.pop_idle(permit) {
+
+                        // Then, check that we can use it...
+                        Ok(conn) => match check_conn(conn, &self.options).await {
+
+                            // All good!
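+                            // (the permit acquired above is not lost here: `pop_idle()`
+                            // converted it into the connection's `DecrementSizeGuard`
+                            // via `DecrementSizeGuard::from_permit()`)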
+                            Ok(live) => return Ok(live),
+
+                            // if the connection isn't usable for one reason or another,
+                            // we get the `DecrementSizeGuard` back to open a new one
+                            Err(guard) => guard,
+                        },
+                        Err(permit) => if let Ok(guard) = self.try_increment_size(permit) {
+                            // we can open a new connection
+                            guard
+                        } else {
+                            log::debug!("woke but was unable to acquire idle connection or open new one; retrying");
+                            continue;
+                        }
+                    };
+
+                    // Attempt to connect...
+                    return self.connection(deadline, guard).await;
+                }
             }
-
-            waited = true;
-        }
-
-        Err(Error::PoolClosed)
+        )
+        .await
+        .map_err(|_| Error::PoolTimedOut)?
     }
 
     pub(super) async fn connection<'s>(
@@ -277,14 +275,13 @@ fn is_beyond_idle<DB: Database>(idle: &Idle<DB>, options: &PoolOptions<DB>) -> b
 async fn check_conn<'s: 'p, 'p, DB: Database>(
     mut conn: Floating<'s, Idle<DB>>,
     options: &'p PoolOptions<DB>,
-) -> Option<Floating<'s, Live<DB>>> {
+) -> Result<Floating<'s, Live<DB>>, DecrementSizeGuard<'s>> {
     // If the connection we pulled has expired, close the connection and
     // immediately create a new connection
     if is_beyond_lifetime(&conn, options) {
         // we're closing the connection either way
         // close the connection but don't really care about the result
-        let _ = conn.close().await;
-        return None;
+        return Err(conn.close().await);
     } else if options.test_before_acquire {
         // Check that the connection is still live
         if let Err(e) = conn.ping().await {
@@ -293,18 +290,18 @@ async fn check_conn<'s: 'p, 'p, DB: Database>(
             // the error itself here isn't necessarily unexpected so WARN is too strong
             log::info!("ping on idle connection returned error: {}", e);
             // connection is broken so don't try to close nicely
-            return None;
+            return Err(conn.close().await);
         }
     } else if let Some(test) = &options.before_acquire {
         match test(&mut conn.live.raw).await {
             Ok(false) => {
                 // connection was rejected by user-defined hook
-                return None;
+                return Err(conn.close().await);
             }
 
             Err(error) => {
                 log::info!("in `before_acquire`: {}", error);
-                return None;
+                return Err(conn.close().await);
             }
 
             Ok(true) => {}
@@ -312,7 +309,7 @@ async fn check_conn<'s: 'p, 'p, DB: Database>(
     }
 
     // No need to re-connect; connection is alive or we don't care
-    Some(conn.into_live())
+    Ok(conn.into_live())
 }
 
 /// if `max_lifetime` or `idle_timeout` is set, spawn a task that reaps senescent connections
@@ -325,18 +322,16 @@ fn spawn_reaper<DB: Database>(pool: &Arc<SharedPool<DB>>) {
         (None, None) => return,
     };
 
     let pool = Arc::clone(&pool);
 
     sqlx_rt::spawn(async move {
         while !pool.is_closed() {
-            // only reap idle connections when no tasks are waiting
-            if pool.waiters.is_empty() {
+            if !pool.idle_conns.is_empty() {
                 do_reap(&pool).await;
             }
 
             sqlx_rt::sleep(period).await;
         }
     });
 }
 
 async fn do_reap<DB: Database>(pool: &SharedPool<DB>) {
@@ -346,7 +341,7 @@ async fn do_reap<DB: Database>(pool: &SharedPool<DB>) {
     // collect connections to reap
     let (reap, keep) = (0..max_reaped)
         // only connections waiting in the queue
-        .filter_map(|_| pool.pop_idle())
+        .filter_map(|_| pool.try_acquire())
        .partition::<Vec<_>, _>(|conn| {
             is_beyond_idle(conn, &pool.options) || is_beyond_lifetime(conn, &pool.options)
         });
@@ -361,38 +356,44 @@ async fn do_reap<DB: Database>(pool: &SharedPool<DB>) {
     }
 }
 
-fn wake_one(waiters: &Waiters) {
-    while let Some(weak) = waiters.pop() {
-        if let Some(waiter) = weak.upgrade() {
-            if waiter.wake() {
-                return;
-            }
-        }
-    }
-}
-
 /// RAII guard returned by `Pool::try_increment_size()` and others.
 ///
 /// Will decrement the pool size if dropped, to avoid semantically "leaking" connections
 /// (where the pool thinks it has more connections than it does).
 pub(in crate::pool) struct DecrementSizeGuard<'a> {
     size: &'a AtomicU32,
-    waiters: &'a Waiters,
+    semaphore: &'a Semaphore,
     dropped: bool,
 }
 
 impl<'a> DecrementSizeGuard<'a> {
-    pub fn new<DB: Database>(pool: &'a SharedPool<DB>) -> Self {
+    /// Create a new guard that will release a semaphore permit on-drop.
+    pub fn new_permit<DB: Database>(pool: &'a SharedPool<DB>) -> Self {
         Self {
             size: &pool.size,
-            waiters: &pool.waiters,
+            semaphore: &pool.semaphore,
             dropped: false,
         }
     }
 
+    pub fn from_permit<DB: Database>(
+        pool: &'a SharedPool<DB>,
+        mut permit: SemaphoreReleaser<'a>,
+    ) -> Self {
+        // here we effectively take ownership of the permit
+        permit.disarm();
+        Self::new_permit(pool)
+    }
+
     /// Return `true` if the internal references point to the same fields in `SharedPool`.
     pub fn same_pool<DB: Database>(&self, pool: &'a SharedPool<DB>) -> bool {
-        ptr::eq(self.size, &pool.size) && ptr::eq(self.waiters, &pool.waiters)
+        ptr::eq(self.size, &pool.size)
+    }
+
+    /// Release the semaphore permit without decreasing the pool size.
+    fn release_permit(self) {
+        self.semaphore.release(1);
+        self.cancel();
     }
 
     pub fn cancel(self) {
@@ -405,73 +406,8 @@ impl Drop for DecrementSizeGuard<'_> {
         assert!(!self.dropped, "double-dropped!");
         self.dropped = true;
         self.size.fetch_sub(1, Ordering::SeqCst);
-        wake_one(&self.waiters);
-    }
-}
-
-struct WaiterInner {
-    woken: AtomicBool,
-    waker: Waker,
-}
-
-impl WaiterInner {
-    /// Wake this waiter if it has not previously been woken.
-    ///
-    /// Return `true` if this waiter was newly woken, or `false` if it was already woken.
-    fn wake(&self) -> bool {
-        // if we were the thread to flip this boolean from false to true
-        if let Ok(_) = self
-            .woken
-            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
-        {
-            self.waker.wake_by_ref();
-            return true;
-        }
-        false
-    }
-}
-
-struct Waiter<'a> {
-    inner: Arc<WaiterInner>,
-    queue: &'a Waiters,
-    actually_woke: bool,
-}
-
-impl<'a> Waiter<'a> {
-    fn push_new(cx: &mut Context<'_>, queue: &'a Waiters) -> Self {
-        let inner = Arc::new(WaiterInner {
-            woken: AtomicBool::new(false),
-            waker: cx.waker().clone(),
-        });
-
-        queue.push(Arc::downgrade(&inner));
-
-        Self {
-            inner,
-            queue,
-            actually_woke: false,
-        }
-    }
-
-    fn is_woken(&self) -> bool {
-        self.inner.woken.load(Ordering::Acquire)
-    }
-
-    fn reset(&mut self) {
-        self.inner
-            .woken
-            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
-            .ok();
-        self.actually_woke = false;
-    }
-}
-
-impl Drop for Waiter<'_> {
-    fn drop(&mut self) {
-        // if we didn't actually wake to get a connection, wake the next task instead
-        if self.is_woken() && !self.actually_woke {
-            wake_one(self.queue);
-        }
+        // and here we release the permit we got on construction
+        self.semaphore.release(1);
     }
 }
diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs
index 10b9c17335..12313e1b34 100644
--- a/sqlx-core/src/pool/mod.rs
+++ b/sqlx-core/src/pool/mod.rs
@@ -256,7 +256,9 @@ impl<DB: Database> Pool<DB> {
     ///
     /// Returns `None` immediately if there are no idle connections available in the pool.
     pub fn try_acquire(&self) -> Option<PoolConnection<DB>> {
-        self.0.try_acquire().map(|conn| conn.attach(&self.0))
+        self.0
+            .try_acquire()
+            .map(|conn| conn.into_live().attach(&self.0))
     }
 
     /// Retrieves a new connection and immediately begins a new transaction.
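
The `try_acquire()` path above is now strictly non-blocking: it returns `None` if the pool is closed, if no semaphore permit is immediately free, or if a permit was obtained but the idle queue happened to be empty (the permit is released again when the `SemaphoreReleaser` drops). A sketch of how a caller might combine it with the waiting path; `get_conn` and the fallback strategy are illustrative, not part of this patch:

```rust
use sqlx::pool::PoolConnection;
use sqlx::postgres::{PgPool, Postgres};

// Try for an idle connection without awaiting; otherwise fall back to
// `acquire()`, which waits (fairly, per the fairness flag passed to the
// semaphore in `SharedPool::new_arc()`) up to `connect_timeout` and may
// open a brand-new connection.
async fn get_conn(pool: &PgPool) -> Result<PoolConnection<Postgres>, sqlx::Error> {
    match pool.try_acquire() {
        Some(conn) => Ok(conn),
        None => pool.acquire().await,
    }
}
```
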
diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs
index a1b07f3721..32313808ff 100644
--- a/sqlx-core/src/pool/options.rs
+++ b/sqlx-core/src/pool/options.rs
@@ -231,19 +231,13 @@ impl<DB: Database> PoolOptions<DB> {
 async fn init_min_connections<DB: Database>(pool: &SharedPool<DB>) -> Result<(), Error> {
     for _ in 0..cmp::max(pool.options.min_connections, 1) {
         let deadline = Instant::now() + pool.options.connect_timeout;
+        let permit = pool.semaphore.acquire(1).await;
 
         // this guard will prevent us from exceeding `max_size`
-        if let Some(guard) = pool.try_increment_size() {
+        if let Ok(guard) = pool.try_increment_size(permit) {
             // [connect] will raise an error when past deadline
             let conn = pool.connection(deadline, guard).await?;
-            let is_ok = pool
-                .idle_conns
-                .push(conn.into_idle().into_leakable())
-                .is_ok();
-
-            if !is_ok {
-                panic!("BUG: connection queue overflow in init_min_connections");
-            }
+            pool.release(conn);
         }
     }
 
diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs
index 590f06b5c5..5688aded5f 100644
--- a/tests/postgres/postgres.rs
+++ b/tests/postgres/postgres.rs
@@ -519,14 +519,19 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
     for i in 0..200 {
         let pool = pool.clone();
         sqlx_rt::spawn(async move {
-            loop {
+            for j in 0.. {
                 if let Err(e) = sqlx::query("select 1 + 1").execute(&pool).await {
                     // normal error at termination of the test
-                    if !matches!(e, sqlx::Error::PoolClosed) {
-                        eprintln!("pool task {} dying due to {}", i, e);
-                        break;
+                    if matches!(e, sqlx::Error::PoolClosed) {
+                        eprintln!("pool task {} exiting normally after {} iterations", i, j);
+                    } else {
+                        eprintln!("pool task {} dying due to {} after {} iterations", i, e, j);
                     }
+                    break;
                 }
+
+                // shouldn't be necessary if the pool is fair
+                // sqlx_rt::yield_now().await;
             }
         });
     }
@@ -547,6 +552,8 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
             })
             .await;
 
+        // this one is necessary since this is a hot loop,
+        // otherwise this task will never be descheduled
         sqlx_rt::yield_now().await;
     }
 });
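
The smoke test above exercises fairness and close-wakeup only indirectly. A more targeted sketch of the `WAKE_ALL_PERMITS` mechanism follows; the connection URL is hypothetical and tokio (with its `full` feature set) plus `anyhow` are assumed as dev-dependencies, as in the test file:

```rust
use std::time::Duration;

use sqlx::postgres::PgPoolOptions;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect("postgres://localhost/test") // hypothetical URL
        .await?;

    let held = pool.acquire().await?; // hold the pool's only permit

    // this task blocks inside `acquire()` waiting for a permit
    let waiter = tokio::spawn({
        let pool = pool.clone();
        async move { pool.acquire().await }
    });

    // give the waiter a chance to enqueue on the semaphore first
    tokio::time::sleep(Duration::from_millis(100)).await;

    // `close()` sets `is_closed` and releases `WAKE_ALL_PERMITS`, so the waiter
    // wakes promptly and observes `Error::PoolClosed` instead of timing out...
    let close = tokio::spawn({
        let pool = pool.clone();
        async move { pool.close().await }
    });

    assert!(matches!(waiter.await?, Err(sqlx::Error::PoolClosed)));

    // ...while `close()` itself keeps waiting until every outstanding
    // connection, including `held`, has returned its permit.
    drop(held);
    close.await?;

    Ok(())
}
```
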