Skip to content

Commit

Permalink
Simplify State lifecycle
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonz-dfinity committed Feb 1, 2025
1 parent 042e60f commit 19088a5
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 113 deletions.
64 changes: 52 additions & 12 deletions rs/backend/src/accounts_store.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! User accounts and transactions.
use crate::accounts_store::schema::accounts_in_unbounded_stable_btree_map::AccountsDbAsUnboundedStableBTreeMap;
use crate::multi_part_transactions_processor::MultiPartTransactionsProcessor;
use crate::state::StableState;
use crate::state::{partitions::PartitionType, with_partitions, StableState};
use crate::stats::Stats;
use candid::CandidType;
use dfn_candid::Candid;
Expand All @@ -17,7 +18,6 @@ use std::collections::{BTreeMap, HashMap, VecDeque};
use std::fmt;
use std::ops::RangeBounds;

pub mod constructors;
pub mod histogram;
pub mod schema;
use schema::{
Expand Down Expand Up @@ -45,7 +45,6 @@ const MAX_IMPORTED_TOKENS: i32 = 20;
pub struct AccountsStore {
// TODO(NNS1-720): Use AccountIdentifier directly as the key for this HashMap
accounts_db: schema::proxy::AccountsDbAsProxy,

accounts_db_stats: AccountsDbStats,
}

Expand Down Expand Up @@ -92,7 +91,7 @@ impl AccountsDbTrait for AccountsStore {
}
}

#[derive(Default, CandidType, Deserialize, Debug, Eq, PartialEq)]
#[derive(Default, CandidType, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct AccountsDbStats {
pub sub_accounts_count: u64,
pub hardware_wallet_accounts_count: u64,
Expand Down Expand Up @@ -318,6 +317,34 @@ pub enum DetachCanisterResponse {
}

impl AccountsStore {
/// Creates a new `AccountsStore`. Should be called in `init`.
pub fn new() -> Self {
let accounts_partition = with_partitions(|partitions| partitions.get(PartitionType::Accounts.memory_id()));
let accounts_db = AccountsDbAsProxy::from(AccountsDb::UnboundedStableBTreeMap(
AccountsDbAsUnboundedStableBTreeMap::new(accounts_partition),
));
let accounts_db_stats = AccountsDbStats::default();

Self {
accounts_db,
accounts_db_stats,
}
}

/// Recovers the state from stable memory. Should be called in `post_upgrade`.
pub fn new_restored(serializable_state: AccountsStoreSerializableState) -> Self {
let AccountsStoreSerializableState { accounts_db_stats } = serializable_state;
let accounts_partition = with_partitions(|partitions| partitions.get(PartitionType::Accounts.memory_id()));
let accounts_db = AccountsDbAsProxy::from(AccountsDb::UnboundedStableBTreeMap(
AccountsDbAsUnboundedStableBTreeMap::load(accounts_partition),
));

Self {
accounts_db,
accounts_db_stats,
}
}

/// Determines whether a migration is being performed.
#[must_use]
#[allow(dead_code)]
Expand Down Expand Up @@ -699,7 +726,26 @@ impl AccountsStore {
}
}

impl StableState for AccountsStore {
/// The component(s) of the `AccountsStore` that are serialized and de-serialized during upgrades.
pub struct AccountsStoreSerializableState {
accounts_db_stats: AccountsDbStats,
}

impl From<AccountsStore> for AccountsStoreSerializableState {
fn from(accounts_store: AccountsStore) -> Self {
let AccountsStore {
// Field(s) persisted by serialization/deserialization through upgrades:
accounts_db_stats,

// Field(s) not persisted by serialization/deserialization through upgrades:
accounts_db: _,
} = accounts_store;

Self { accounts_db_stats }
}
}

impl StableState for AccountsStoreSerializableState {
fn encode(&self) -> Vec<u8> {
// Accounts are now in stable structures and no longer in a simple map
// on the heap. So we don't need to encode them here.
Expand Down Expand Up @@ -778,13 +824,7 @@ impl StableState for AccountsStore {
return Err("Accounts DB stats should be present since the stable structures migration.".to_string());
};

Ok(AccountsStore {
// Because the stable structures migration is finished, accounts_db
// will be replaced with an AccountsDbAsUnboundedStableBTreeMap in
// State::from(Partitions) so it doesn't matter what we set here.
accounts_db: AccountsDbAsProxy::default(),
accounts_db_stats,
})
Ok(Self { accounts_db_stats })
}
}

Expand Down
29 changes: 0 additions & 29 deletions rs/backend/src/accounts_store/constructors.rs

This file was deleted.

2 changes: 2 additions & 0 deletions rs/backend/src/assets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl ContentEncoding {
const LABEL_ASSETS: &[u8] = b"http_assets";

#[derive(Default, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Clone))]
pub struct AssetHashes(RbTree<Vec<u8>, Hash>);

impl From<&Assets> for AssetHashes {
Expand Down Expand Up @@ -115,6 +116,7 @@ impl Asset {
}

#[derive(Default, CandidType, Deserialize, PartialEq, Eq, Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Assets(HashMap<String, Asset>);

impl Assets {
Expand Down
134 changes: 91 additions & 43 deletions rs/backend/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
pub mod partitions;
#[cfg(test)]
pub mod tests;
mod with_accounts_in_stable_memory;

use self::partitions::{PartitionType, Partitions};
use crate::accounts_store::schema::accounts_in_unbounded_stable_btree_map::AccountsDbAsUnboundedStableBTreeMap;
use crate::accounts_store::schema::proxy::AccountsDb;
use crate::accounts_store::AccountsStore;
use self::partitions::Partitions;
use crate::accounts_store::{AccountsStore, AccountsStoreSerializableState};
use crate::assets::AssetHashes;
use crate::assets::Assets;
use crate::perf::PerformanceCounts;
Expand Down Expand Up @@ -86,7 +83,7 @@ pub fn restore_state() {
/// # Panics
/// Panics when the function is called before the `init_state` or `restore_state` is called.
pub fn save_state() {
STATE.with_borrow(|s| s.as_ref().expect("State not initialized").save());
STATE.take().expect("State not initialized").save();
}

/// An accessor for the state.
Expand Down Expand Up @@ -117,45 +114,60 @@ pub fn reset_partitions() {
PARTITIONS.replace(Partitions::from(DefaultMemoryImpl::default()));
}

#[allow(clippy::new_without_default)]
impl State {
/// Creates new state. Should be called in `init`.
#[must_use]
pub fn new() -> Self {
let accounts_partition = with_partitions(|p| p.get(PartitionType::Accounts.memory_id()));
let accounts_store = AccountsStore::from(AccountsDb::UnboundedStableBTreeMap(
AccountsDbAsUnboundedStableBTreeMap::new(accounts_partition),
));
State {
/// Represents the parts of the `State` that are serialized and de-serialized during upgrades.
struct SerializableState {
accounts_store: AccountsStoreSerializableState,
assets: Assets,
tvl_state: TvlState,
}

impl From<State> for SerializableState {
fn from(state: State) -> Self {
// Destructure and consume the state.
let State {
// Field(s) persisted by serialization/deserialization through upgrades:
accounts_store,
assets: Assets::default(),
asset_hashes: AssetHashes::default(),
performance: PerformanceCounts::default(),
tvl_state: TvlState::default(),
assets,
tvl_state,

// Field(s) not persisted by serialization/deserialization through upgrades:
asset_hashes: _,
performance: _,
} = state;

let accounts_store = AccountsStoreSerializableState::from(accounts_store);

SerializableState {
accounts_store,
assets,
tvl_state,
}
}
}

/// Recovers the state from stable memory. Should be called in `post_upgrade`.
#[must_use]
pub fn new_restored() -> Self {
println!("START state::new_restored: ())");
let accounts_partition = with_partitions(|p| p.get(PartitionType::Accounts.memory_id()));
let mut state = Self::recover_heap_from_managed_memory();
let accounts_db =
AccountsDb::UnboundedStableBTreeMap(AccountsDbAsUnboundedStableBTreeMap::load(accounts_partition));
// Replace the default accountsdb created by `serde` with the one from stable memory.
let _deserialized_accounts_db = state.accounts_store.replace_accounts_db(accounts_db);
println!("END state::new_restored: ()");
state
}
impl From<SerializableState> for State {
fn from(serializable: SerializableState) -> Self {
let SerializableState {
accounts_store,
assets,
tvl_state,
} = serializable;

let accounts_store = AccountsStore::new_restored(accounts_store);
let asset_hashes = AssetHashes::from(&assets);
let performance = PerformanceCounts::default();

/// Saves the state to stable memory. Should be called in `pre_upgrade`.
pub fn save(&self) {
self.save_heap_to_managed_memory();
Self {
accounts_store,
assets,
asset_hashes,
performance,
tvl_state,
}
}
}

impl StableState for State {
impl StableState for SerializableState {
fn encode(&self) -> Vec<u8> {
Candid((
self.accounts_store.encode(),
Expand All @@ -169,17 +181,53 @@ impl StableState for State {
fn decode(bytes: Vec<u8>) -> Result<Self, String> {
let (account_store_bytes, assets_bytes, tvl_state_bytes) = Candid::from_bytes(bytes).map(|c| c.0)?;

let accounts_store = AccountsStoreSerializableState::decode(account_store_bytes)?;
let assets = Assets::decode(assets_bytes)?;
let asset_hashes = AssetHashes::from(&assets);
let performance = PerformanceCounts::default();
let tvl_state = TvlState::decode(tvl_state_bytes)?;

Ok(State {
accounts_store: AccountsStore::decode(account_store_bytes)?,
Ok(Self {
accounts_store,
assets,
asset_hashes,
performance,
tvl_state,
})
}
}

#[allow(clippy::new_without_default)]
impl State {
/// Creates new state. Should be called in `init`.
#[must_use]
pub fn new() -> Self {
State {
accounts_store: AccountsStore::new(),
assets: Assets::default(),
asset_hashes: AssetHashes::default(),
performance: PerformanceCounts::default(),
tvl_state: TvlState::default(),
}
}

/// Recovers the state from stable memory. Should be called in `post_upgrade`.
///
/// # Panics
/// Panics if the state cannot be decoded.
#[must_use]
pub fn new_restored() -> Self {
println!("START state::new_restored: ())");
let candid_bytes = with_partitions(Partitions::read_bytes_from_managed_memory);
let decoded = SerializableState::decode(candid_bytes).unwrap_or_else(|e| {
panic!("Decoding stable memory failed. Error: {e:?}");
});
let state = State::from(decoded);
println!("END state::new_restored: ()");
state
}

// Saves the state. Should be called in `pre_upgrade`.
fn save(self) {
let serializable = SerializableState::from(self);
println!("START state::save_heap: ()");
let bytes = serializable.encode();
with_partitions(|partitions| partitions.write_bytes_to_managed_memory(&bytes));
}
}
2 changes: 1 addition & 1 deletion rs/backend/src/state/partitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl Partitions {
}

/// Writes, growing the memory if necessary.
pub fn growing_write(&self, memory_id: MemoryId, offset: u64, bytes: &[u8]) {
fn growing_write(&self, memory_id: MemoryId, offset: u64, bytes: &[u8]) {
let memory = self.get(memory_id);
let min_pages: u64 = u64::try_from(bytes.len())
.unwrap_or_else(|err| unreachable!("Buffer for growing_write is longer than 2**64 bytes?? Err: {err}"))
Expand Down
Loading

0 comments on commit 19088a5

Please sign in to comment.