Skip to content

Commit

Permalink
Refactor WriteTransaction to use TransactionGuard
Browse files Browse the repository at this point in the history
  • Loading branch information
cberner committed Dec 21, 2023
1 parent 213e675 commit 10ef672
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 154 deletions.
135 changes: 54 additions & 81 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ use std::io::ErrorKind;
use std::marker::PhantomData;
use std::ops::RangeFull;
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::sync::{Arc, Mutex};

use crate::error::TransactionError;
use crate::multimap_table::{parse_subtree_roots, DynamicCollection};
Expand Down Expand Up @@ -56,23 +55,6 @@ pub trait StorageBackend: 'static + Debug + Send + Sync {
fn write(&self, offset: u64, data: &[u8]) -> std::result::Result<(), io::Error>;
}

struct AtomicTransactionId {
inner: AtomicU64,
}

impl AtomicTransactionId {
fn new(last_id: TransactionId) -> Self {
Self {
inner: AtomicU64::new(last_id.raw_id()),
}
}

fn next(&self) -> TransactionId {
let id = self.inner.fetch_add(1, Ordering::AcqRel);
TransactionId::new(id)
}
}

pub trait TableHandle: Sealed {
// Returns the name of the table
fn name(&self) -> &str;
Expand Down Expand Up @@ -233,11 +215,38 @@ impl<'a, K: RedbKey + 'static, V: RedbKey + 'static> Display for MultimapTableDe
}

pub(crate) struct TransactionGuard {
transaction_tracker: Arc<Mutex<TransactionTracker>>,
transaction_tracker: Arc<TransactionTracker>,
transaction_id: Option<TransactionId>,
write_transaction: bool,
}

impl TransactionGuard {
pub(crate) fn new_read(
transaction_id: TransactionId,
tracker: Arc<TransactionTracker>,
) -> Self {
Self {
transaction_tracker: tracker,
transaction_id: Some(transaction_id),
write_transaction: false,
}
}

pub(crate) fn new_write(
transaction_id: TransactionId,
tracker: Arc<TransactionTracker>,
) -> Self {
Self {
transaction_tracker: tracker,
transaction_id: Some(transaction_id),
write_transaction: true,
}
}

pub(crate) fn id(&self) -> TransactionId {
self.transaction_id.unwrap()
}

fn leak(mut self) -> TransactionId {
self.transaction_id.take().unwrap()
}
Expand All @@ -246,10 +255,13 @@ impl TransactionGuard {
impl Drop for TransactionGuard {
fn drop(&mut self) {
if let Some(transaction_id) = self.transaction_id {
self.transaction_tracker
.lock()
.unwrap()
.deallocate_read_transaction(transaction_id);
if self.write_transaction {
self.transaction_tracker
.end_write_transaction(transaction_id);
} else {
self.transaction_tracker
.deallocate_read_transaction(transaction_id);
}
}
}
}
Expand Down Expand Up @@ -286,10 +298,7 @@ impl Drop for TransactionGuard {
/// ```
pub struct Database {
mem: Arc<TransactionalMemory>,
next_transaction_id: AtomicTransactionId,
transaction_tracker: Arc<Mutex<TransactionTracker>>,
live_write_transaction: Mutex<Option<TransactionId>>,
live_write_transaction_available: Condvar,
transaction_tracker: Arc<TransactionTracker>,
}

impl Database {
Expand All @@ -306,30 +315,6 @@ impl Database {
Self::builder().open(path)
}

pub(crate) fn start_write_transaction(&self) -> TransactionId {
let mut live_write_transaction = self.live_write_transaction.lock().unwrap();
while live_write_transaction.is_some() {
live_write_transaction = self
.live_write_transaction_available
.wait(live_write_transaction)
.unwrap();
}
assert!(live_write_transaction.is_none());
let transaction_id = self.next_transaction_id.next();
#[cfg(feature = "logging")]
info!("Beginning write transaction id={:?}", transaction_id);
*live_write_transaction = Some(transaction_id);

transaction_id
}

pub(crate) fn end_write_transaction(&self, id: TransactionId) {
let mut live_write_transaction = self.live_write_transaction.lock().unwrap();
assert_eq!(live_write_transaction.unwrap(), id);
*live_write_transaction = None;
self.live_write_transaction_available.notify_one();
}

pub(crate) fn get_memory(&self) -> Arc<TransactionalMemory> {
self.mem.clone()
}
Expand Down Expand Up @@ -395,12 +380,7 @@ impl Database {
if txn.list_persistent_savepoints()?.next().is_some() {
return Err(CompactionError::PersistentSavepointExists);
}
if self
.transaction_tracker
.lock()
.unwrap()
.any_savepoint_exists()
{
if self.transaction_tracker.any_savepoint_exists() {
return Err(CompactionError::EphemeralSavepointExists);
}
txn.set_durability(Durability::Paranoid);
Expand Down Expand Up @@ -449,7 +429,7 @@ impl Database {
) -> Result {
let freed_list = Arc::new(Mutex::new(vec![]));
let table_tree = TableTree::new(system_root, mem.clone(), freed_list);
let fake_transaction_tracker = Arc::new(Mutex::new(TransactionTracker::new()));
let fake_transaction_tracker = Arc::new(TransactionTracker::new(TransactionId::new(0)));
if let Some(savepoint_table_def) = table_tree
.get_table::<SavepointId, SerializedSavepoint>(
SAVEPOINT_TABLE.name(),
Expand Down Expand Up @@ -718,18 +698,13 @@ impl Database {

let db = Database {
mem,
next_transaction_id: AtomicTransactionId::new(next_transaction_id),
transaction_tracker: Arc::new(Mutex::new(TransactionTracker::new())),
live_write_transaction: Mutex::new(None),
live_write_transaction_available: Condvar::new(),
transaction_tracker: Arc::new(TransactionTracker::new(next_transaction_id)),
};

// Restore the tracker state for any persistent savepoints
let txn = db.begin_write().map_err(|e| e.into_storage_error())?;
if let Some(next_id) = txn.next_persistent_savepoint_id()? {
db.transaction_tracker
.lock()
.unwrap()
.restore_savepoint_counter_state(next_id);
}
for id in txn.list_persistent_savepoints()? {
Expand All @@ -743,8 +718,6 @@ impl Database {
},
};
db.transaction_tracker
.lock()
.unwrap()
.register_persistent_savepoint(&savepoint);
}
txn.abort()?;
Expand All @@ -753,22 +726,18 @@ impl Database {
}

fn allocate_read_transaction(&self) -> Result<TransactionGuard> {
let mut guard = self.transaction_tracker.lock().unwrap();
let id = self.mem.get_last_committed_transaction_id()?;
guard.register_read_transaction(id);

Ok(TransactionGuard {
transaction_tracker: self.transaction_tracker.clone(),
transaction_id: Some(id),
})
let id = self
.transaction_tracker
.register_read_transaction(&self.mem)?;

Ok(TransactionGuard::new_read(
id,
self.transaction_tracker.clone(),
))
}

pub(crate) fn allocate_savepoint(&self) -> Result<(SavepointId, TransactionId)> {
let id = self
.transaction_tracker
.lock()
.unwrap()
.allocate_savepoint();
let id = self.transaction_tracker.allocate_savepoint();
Ok((id, self.allocate_read_transaction()?.leak()))
}

Expand All @@ -783,7 +752,11 @@ impl Database {
/// write may be in progress at a time. If a write is in progress, this function will block
/// until it completes.
pub fn begin_write(&self) -> Result<WriteTransaction, TransactionError> {
WriteTransaction::new(self, self.transaction_tracker.clone()).map_err(|e| e.into())
let guard = TransactionGuard::new_write(
self.transaction_tracker.start_write_transaction(),
self.transaction_tracker.clone(),
);
WriteTransaction::new(self, guard, self.transaction_tracker.clone()).map_err(|e| e.into())
}

/// Begins a read transaction
Expand Down
Loading

0 comments on commit 10ef672

Please sign in to comment.