Skip to content

Commit

Permalink
feat(db): Allow creating owned Postgres connections (#2654)
Browse files Browse the repository at this point in the history
## What ❔

- Changes `Connection` so that it has `'static` lifetime if created from
a pool (i.e., when it is non-transactional).
- Simplifies `ReadStorageFactory` and `MainBatchExecutor` accordingly.

## Why ❔

Reduces complexity. `'static` connections can be sent to a Tokio task
etc., meaning improved DevEx.

## Checklist

- [x] PR title corresponds to the body of PR (we generate changelog
entries from PRs).
- [x] Tests for the changes have been added / updated.
- [x] Documentation comments have been added / updated.
- [x] Code has been formatted via `zk fmt` and `zk lint`.
  • Loading branch information
slowli authored Aug 16, 2024
1 parent 1696e6e commit 47a082b
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 187 deletions.
35 changes: 20 additions & 15 deletions core/lib/db_connection/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{
collections::HashMap,
fmt, io,
marker::PhantomData,
panic::Location,
sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
Arc, Mutex, Weak,
},
time::{Instant, SystemTime},
};
Expand Down Expand Up @@ -98,14 +99,14 @@ impl TracedConnections {
}
}

struct PooledConnection<'a> {
struct PooledConnection {
connection: PoolConnection<Postgres>,
tags: Option<ConnectionTags>,
created_at: Instant,
traced: Option<(&'a TracedConnections, usize)>,
traced: (Weak<TracedConnections>, usize),
}

impl fmt::Debug for PooledConnection<'_> {
impl fmt::Debug for PooledConnection {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("PooledConnection")
Expand All @@ -115,7 +116,7 @@ impl fmt::Debug for PooledConnection<'_> {
}
}

impl Drop for PooledConnection<'_> {
impl Drop for PooledConnection {
fn drop(&mut self) {
if let Some(tags) = &self.tags {
let lifetime = self.created_at.elapsed();
Expand All @@ -132,15 +133,17 @@ impl Drop for PooledConnection<'_> {
);
}
}
if let Some((connections, id)) = self.traced {
connections.mark_as_dropped(id);

let (traced_connections, id) = &self.traced;
if let Some(connections) = traced_connections.upgrade() {
connections.mark_as_dropped(*id);
}
}
}

#[derive(Debug)]
enum ConnectionInner<'a> {
Pooled(PooledConnection<'a>),
Pooled(PooledConnection),
Transaction {
transaction: Transaction<'a, Postgres>,
tags: Option<&'a ConnectionTags>,
Expand All @@ -156,7 +159,7 @@ pub trait DbMarker: 'static + Send + Sync + Clone {}
#[derive(Debug)]
pub struct Connection<'a, DB: DbMarker> {
inner: ConnectionInner<'a>,
_marker: std::marker::PhantomData<DB>,
_marker: PhantomData<DB>,
}

impl<'a, DB: DbMarker> Connection<'a, DB> {
Expand All @@ -166,21 +169,23 @@ impl<'a, DB: DbMarker> Connection<'a, DB> {
pub(crate) fn from_pool(
connection: PoolConnection<Postgres>,
tags: Option<ConnectionTags>,
traced_connections: Option<&'a TracedConnections>,
traced_connections: Option<&Arc<TracedConnections>>,
) -> Self {
let created_at = Instant::now();
let inner = ConnectionInner::Pooled(PooledConnection {
connection,
tags,
created_at,
traced: traced_connections.map(|connections| {
traced: if let Some(connections) = traced_connections {
let id = connections.acquire(tags, created_at);
(connections, id)
}),
(Arc::downgrade(connections), id)
} else {
(Weak::new(), 0)
},
});
Self {
inner,
_marker: Default::default(),
_marker: PhantomData,
}
}

Expand All @@ -196,7 +201,7 @@ impl<'a, DB: DbMarker> Connection<'a, DB> {
};
Ok(Connection {
inner,
_marker: Default::default(),
_marker: PhantomData,
})
}

Expand Down
8 changes: 4 additions & 4 deletions core/lib/db_connection/src/connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ impl<DB: DbMarker> ConnectionPool<DB> {
///
/// This method is intended to be used in crucial contexts, where the
/// database access is must-have (e.g. block committer).
pub async fn connection(&self) -> DalResult<Connection<'_, DB>> {
pub async fn connection(&self) -> DalResult<Connection<'static, DB>> {
self.connection_inner(None).await
}

Expand All @@ -361,7 +361,7 @@ impl<DB: DbMarker> ConnectionPool<DB> {
pub fn connection_tagged(
&self,
requester: &'static str,
) -> impl Future<Output = DalResult<Connection<'_, DB>>> + '_ {
) -> impl Future<Output = DalResult<Connection<'static, DB>>> + '_ {
let location = Location::caller();
async move {
let tags = ConnectionTags {
Expand All @@ -375,7 +375,7 @@ impl<DB: DbMarker> ConnectionPool<DB> {
async fn connection_inner(
&self,
tags: Option<ConnectionTags>,
) -> DalResult<Connection<'_, DB>> {
) -> DalResult<Connection<'static, DB>> {
let acquire_latency = CONNECTION_METRICS.acquire.start();
let conn = self.acquire_connection_retried(tags.as_ref()).await?;
let elapsed = acquire_latency.observe();
Expand All @@ -386,7 +386,7 @@ impl<DB: DbMarker> ConnectionPool<DB> {
Ok(Connection::<DB>::from_pool(
conn,
tags,
self.traced_connections.as_deref(),
self.traced_connections.as_ref(),
))
}

Expand Down
3 changes: 1 addition & 2 deletions core/lib/state/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ pub use self::{
},
shadow_storage::ShadowStorage,
storage_factory::{
BatchDiff, OwnedPostgresStorage, OwnedStorage, PgOrRocksdbStorage, ReadStorageFactory,
RocksdbWithMemory,
BatchDiff, OwnedStorage, PgOrRocksdbStorage, ReadStorageFactory, RocksdbWithMemory,
},
};

Expand Down
79 changes: 23 additions & 56 deletions core/lib/state/src/storage_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ use zksync_vm_interface::storage::ReadStorage;

use crate::{PostgresStorage, RocksdbStorage, RocksdbStorageBuilder, StateKeeperColumnFamily};

/// Storage with a static lifetime that can be sent to Tokio tasks etc.
pub type OwnedStorage = PgOrRocksdbStorage<'static>;

/// Factory that can produce storage instances on demand. The storage type is encapsulated as a type param
/// (mostly for testing purposes); the default is [`OwnedStorage`].
#[async_trait]
Expand All @@ -35,8 +38,9 @@ impl ReadStorageFactory for ConnectionPool<Core> {
_stop_receiver: &watch::Receiver<bool>,
l1_batch_number: L1BatchNumber,
) -> anyhow::Result<Option<OwnedStorage>> {
let storage = OwnedPostgresStorage::new(self.clone(), l1_batch_number);
Ok(Some(storage.into()))
let connection = self.connection().await?;
let storage = OwnedStorage::postgres(connection, l1_batch_number).await?;
Ok(Some(storage))
}
}

Expand All @@ -61,31 +65,29 @@ pub struct RocksdbWithMemory {
pub batch_diffs: Vec<BatchDiff>,
}

/// Owned Postgres-backed VM storage for a certain L1 batch.
/// A [`ReadStorage`] implementation that uses either [`PostgresStorage`] or [`RocksdbStorage`]
/// underneath.
#[derive(Debug)]
pub struct OwnedPostgresStorage {
connection_pool: ConnectionPool<Core>,
l1_batch_number: L1BatchNumber,
pub enum PgOrRocksdbStorage<'a> {
/// Implementation over a Postgres connection.
Postgres(PostgresStorage<'a>),
/// Implementation over a RocksDB cache instance.
Rocksdb(RocksdbStorage),
/// Implementation over a RocksDB cache instance with in-memory DB diffs.
RocksdbWithMemory(RocksdbWithMemory),
}

impl OwnedPostgresStorage {
/// Creates a VM storage for the specified batch number.
pub fn new(connection_pool: ConnectionPool<Core>, l1_batch_number: L1BatchNumber) -> Self {
Self {
connection_pool,
l1_batch_number,
}
}

/// Returns a [`ReadStorage`] implementation backed by Postgres
impl PgOrRocksdbStorage<'static> {
/// Creates a Postgres-based storage. Because of the `'static` lifetime requirement, `connection` must be
/// non-transactional.
///
/// # Errors
///
/// Propagates Postgres errors.
pub async fn borrow(&self) -> anyhow::Result<PgOrRocksdbStorage<'_>> {
let l1_batch_number = self.l1_batch_number;
let mut connection = self.connection_pool.connection().await?;

/// Propagates Postgres I/O errors.
pub async fn postgres(
mut connection: Connection<'static, Core>,
l1_batch_number: L1BatchNumber,
) -> anyhow::Result<Self> {
let l2_block_number = if let Some((_, l2_block_number)) = connection
.blocks_dal()
.get_l2_block_range_of_l1_batch(l1_batch_number)
Expand Down Expand Up @@ -114,42 +116,7 @@ impl OwnedPostgresStorage {
.into(),
)
}
}

/// Owned version of [`PgOrRocksdbStorage`]. It is thus possible to send to blocking tasks for VM execution.
#[derive(Debug)]
pub enum OwnedStorage {
/// Readily initialized storage with a static lifetime.
Static(PgOrRocksdbStorage<'static>),
/// Storage that must be `borrow()`ed from.
Lending(OwnedPostgresStorage),
}

impl From<OwnedPostgresStorage> for OwnedStorage {
fn from(storage: OwnedPostgresStorage) -> Self {
Self::Lending(storage)
}
}

impl From<PgOrRocksdbStorage<'static>> for OwnedStorage {
fn from(storage: PgOrRocksdbStorage<'static>) -> Self {
Self::Static(storage)
}
}

/// A [`ReadStorage`] implementation that uses either [`PostgresStorage`] or [`RocksdbStorage`]
/// underneath.
#[derive(Debug)]
pub enum PgOrRocksdbStorage<'a> {
/// Implementation over a Postgres connection.
Postgres(PostgresStorage<'a>),
/// Implementation over a RocksDB cache instance.
Rocksdb(RocksdbStorage),
/// Implementation over a RocksDB cache instance with in-memory DB diffs.
RocksdbWithMemory(RocksdbWithMemory),
}

impl PgOrRocksdbStorage<'static> {
/// Catches up RocksDB synchronously (i.e. assumes the gap is small) and
/// returns a [`ReadStorage`] implementation backed by caught-up RocksDB.
///
Expand Down
2 changes: 1 addition & 1 deletion core/node/api_server/src/web3/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ impl RpcState {
#[track_caller]
pub(crate) fn acquire_connection(
&self,
) -> impl Future<Output = Result<Connection<'_, Core>, Web3Error>> + '_ {
) -> impl Future<Output = Result<Connection<'static, Core>, Web3Error>> + '_ {
self.connection_pool
.connection_tagged("api")
.map_err(|err| err.generalize().into())
Expand Down
Loading

0 comments on commit 47a082b

Please sign in to comment.