Skip to content

Commit

Permalink
Simplify State lifecycle
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonz-dfinity committed Jan 31, 2025
1 parent 00893e4 commit 3b7de2f
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 108 deletions.
108 changes: 73 additions & 35 deletions rs/backend/src/accounts_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ const MAX_IMPORTED_TOKENS: i32 = 20;
pub struct AccountsStore {
    // TODO(NNS1-720): Use AccountIdentifier directly as the key for this HashMap
    // The accounts database. Held behind a proxy; not serialized during
    // upgrades (see `AccountsStoreSerializableState`).
    accounts_db: schema::proxy::AccountsDbAsProxy,
    // Aggregate statistics about the accounts DB; persisted through upgrades.
    accounts_db_stats: AccountsDbStats,
}

Expand Down Expand Up @@ -92,7 +91,7 @@ impl AccountsDbTrait for AccountsStore {
}
}

#[derive(Default, CandidType, Deserialize, Debug, Eq, PartialEq)]
#[derive(Default, CandidType, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct AccountsDbStats {
pub sub_accounts_count: u64,
pub hardware_wallet_accounts_count: u64,
Expand Down Expand Up @@ -699,13 +698,74 @@ impl AccountsStore {
}
}

impl StableState for AccountsStore {
/// The component(s) of the `AccountsStore` that are serialized and deserialized during upgrades.
///
/// Notably excludes `accounts_db`: the accounts DB lives in stable structures
/// and is supplied separately when reconstructing an `AccountsStore` (see the
/// `From<(AccountsStoreSerializableState, AccountsDbAsProxy)>` impl).
pub struct AccountsStoreSerializableState {
    // Aggregate statistics about the accounts DB; the only field that is
    // carried across upgrades via serialization.
    accounts_db_stats: AccountsDbStats,
}

impl From<AccountsStore> for AccountsStoreSerializableState {
    /// Extracts the upgrade-persisted subset of an `AccountsStore`.
    fn from(store: AccountsStore) -> Self {
        // Exhaustive destructuring (no `..`): adding a field to
        // `AccountsStore` forces an explicit decision here about whether it
        // must survive upgrades.
        let AccountsStore {
            // Not serialized: the accounts DB lives in stable structures and
            // is reattached after the upgrade.
            accounts_db: _,
            // Persisted by serialization/deserialization through upgrades.
            accounts_db_stats,
        } = store;

        Self { accounts_db_stats }
    }
}

impl From<(AccountsStoreSerializableState, AccountsDbAsProxy)> for AccountsStore {
fn from((serializable, accounts_db): (AccountsStoreSerializableState, AccountsDbAsProxy)) -> Self {
let AccountsStoreSerializableState { accounts_db_stats } = serializable;

Self {
accounts_db,
accounts_db_stats,
}
}
}

// The tuple layout written by `encode`. Most slots are legacy fields kept (as
// empty/zero placeholder values) purely for backwards compatibility with the
// historical on-wire format; only the final `Option<AccountsDbStats>` carries
// live data.
#[allow(clippy::zero_sized_map_values)] // Placeholder maps deliberately use zero-sized `candid::Empty` values.
type SerializableStateForEncoding = (
    // Formerly the accounts map; always encoded empty now that accounts live
    // in stable structures.
    BTreeMap<Vec<u8>, candid::Empty>,
    // hardware_wallets_and_sub_accounts: unused, encoded for backwards compatibility.
    HashMap<AccountIdentifier, AccountWrapper>,
    // Legacy slot, unused (encoded as an empty placeholder).
    HashMap<(AccountIdentifier, AccountIdentifier), candid::Empty>,
    // Legacy slot, unused (encoded as an empty placeholder).
    VecDeque<candid::Empty>,
    // Legacy slot, unused (encoded as an empty placeholder).
    HashMap<AccountIdentifier, candid::Empty>,
    // Legacy slot, unused.
    Option<BlockIndex>,
    // Legacy slot, unused.
    MultiPartTransactionsProcessor,
    // last_ledger_sync_timestamp_nanos: unused, encoded for backwards compatibility.
    u64,
    // neurons_topped_up_count: unused, encoded for backwards compatibility.
    u64,
    // The only live payload: accounts DB statistics.
    Option<AccountsDbStats>,
);
// The tuple layout read by `decode`. All legacy slots are skipped via
// `candid::Reserved`, except the one noted below; only the final
// `Option<AccountsDbStats>` carries live data.
type SerializableStateForDecoding = (
    candid::Reserved,
    // TODO: Change to `candid::Reserved` and remove `AccountWrapper` after
    // we've deployed to mainnet. If we do it now, decoding will break
    // because of the skip quota. The decoder will think there is an
    // attack with a lot of unnecessary data in a request.
    HashMap<AccountIdentifier, AccountWrapper>,
    candid::Reserved,
    candid::Reserved,
    candid::Reserved,
    candid::Reserved,
    candid::Reserved,
    // last_ledger_sync_timestamp_nanos (unused, skipped).
    candid::Reserved,
    // neurons_topped_up_count (unused, skipped).
    candid::Reserved,
    // The only live payload: accounts DB statistics.
    Option<AccountsDbStats>,
);

impl StableState for AccountsStoreSerializableState {
fn encode(&self) -> Vec<u8> {
// Accounts are now in stable structures and no longer in a simple map
// on the heap. So we don't need to encode them here.
let empty_accounts = BTreeMap::<Vec<u8>, candid::Empty>::new();
Candid((
empty_accounts,
let AccountsStoreSerializableState { accounts_db_stats } = self;

let candid_object: SerializableStateForEncoding = (
// Accounts are now in stable structures and no longer in a simple map
// on the heap. So we don't need to encode them here.
BTreeMap::<Vec<u8>, candid::Empty>::new(),
// hardware_wallets_and_sub_accounts is unused but we need to encode
// it for backwards compatibility.
// TODO: Change AccountWrapper to candid::Empty after we've
Expand Down Expand Up @@ -734,10 +794,9 @@ impl StableState for AccountsStore {
// neurons_topped_up_count is unused but we need to encode
// it for backwards compatibility.
0u64,
Some(&self.accounts_db_stats),
))
.into_bytes()
.unwrap()
Some(accounts_db_stats.clone()),
);
Candid(candid_object).into_bytes().unwrap()
}

fn decode(bytes: Vec<u8>) -> Result<Self, String> {
Expand All @@ -757,34 +816,13 @@ impl StableState for AccountsStore {
_last_ledger_sync_timestamp_nanos,
_neurons_topped_up_count,
accounts_db_stats_maybe,
): (
candid::Reserved,
// TODO: Change to candid:Reserved and remove AccountWrapper after
// we've deployed to mainnet. If we do it now, decoding will break
// because of the skip quota. The decoder will think there is an
// attack with a lot of unnecessary data in a request.
HashMap<AccountIdentifier, AccountWrapper>,
candid::Reserved,
candid::Reserved,
candid::Reserved,
candid::Reserved,
candid::Reserved,
candid::Reserved,
candid::Reserved,
Option<AccountsDbStats>,
) = Candid::from_bytes(bytes).map(|c| c.0)?;
): SerializableStateForDecoding = Candid::from_bytes(bytes).map(|c| c.0)?;

let Some(accounts_db_stats) = accounts_db_stats_maybe else {
return Err("Accounts DB stats should be present since the stable structures migration.".to_string());
};

Ok(AccountsStore {
// Because the stable structures migration is finished, accounts_db
// will be replaced with an AccountsDbAsUnboundedStableBTreeMap in
// State::from(Partitions) so it doesn't matter what we set here.
accounts_db: AccountsDbAsProxy::default(),
accounts_db_stats,
})
Ok(Self { accounts_db_stats })
}
}

Expand Down
2 changes: 2 additions & 0 deletions rs/backend/src/assets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl ContentEncoding {
const LABEL_ASSETS: &[u8] = b"http_assets";

// A tree of hashes keyed by byte paths, derived from `Assets` (see the
// `From<&Assets>` impl below).
// NOTE(review): the nearby `LABEL_ASSETS = b"http_assets"` label suggests this
// backs HTTP asset certification — confirm.
#[derive(Default, Debug, Eq, PartialEq)]
// `Clone` is only needed by tests.
#[cfg_attr(test, derive(Clone))]
pub struct AssetHashes(RbTree<Vec<u8>, Hash>);

impl From<&Assets> for AssetHashes {
Expand Down Expand Up @@ -115,6 +116,7 @@ impl Asset {
}

// The assets stored by the canister, keyed by string (presumably the asset
// path — confirm against insert sites).
#[derive(Default, CandidType, Deserialize, PartialEq, Eq, Debug)]
// `Clone` is only needed by tests.
#[cfg_attr(test, derive(Clone))]
pub struct Assets(HashMap<String, Asset>);

impl Assets {
Expand Down
147 changes: 103 additions & 44 deletions rs/backend/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
pub mod partitions;
#[cfg(test)]
pub mod tests;
mod with_accounts_in_stable_memory;

use self::partitions::{PartitionType, Partitions};
use crate::accounts_store::schema::accounts_in_unbounded_stable_btree_map::AccountsDbAsUnboundedStableBTreeMap;
use crate::accounts_store::schema::proxy::AccountsDb;
use crate::accounts_store::AccountsStore;
use crate::accounts_store::schema::proxy::{AccountsDb, AccountsDbAsProxy};
use crate::accounts_store::{AccountsStore, AccountsStoreSerializableState};
use crate::assets::AssetHashes;
use crate::assets::Assets;
use crate::perf::PerformanceCounts;
use crate::tvl::state::TvlState;

use dfn_candid::Candid;
use candid::{Decode, Encode};
use ic_cdk::println;
use ic_stable_structures::DefaultMemoryImpl;
use on_wire::{FromWire, IntoWire};
use std::cell::RefCell;

pub struct State {
Expand Down Expand Up @@ -86,7 +84,7 @@ pub fn restore_state() {
/// # Panics
/// Panics when the function is called before the `init_state` or `restore_state` is called.
pub fn save_state() {
STATE.with_borrow(|s| s.as_ref().expect("State not initialized").save());
STATE.take().expect("State not initialized").save();
}

/// An accessor for the state.
Expand Down Expand Up @@ -117,6 +115,86 @@ pub fn reset_partitions() {
PARTITIONS.replace(Partitions::from(DefaultMemoryImpl::default()));
}

/// Represents the parts of the `State` that are serialized and deserialized during upgrades.
/// Represents the parts of the `State` that are serialized and deserialized during upgrades.
///
/// Excludes `asset_hashes` (recomputed from `assets`) and `performance`
/// (reset to defaults) — see `From<(SerializableState, AccountsDbAsProxy)>`.
struct SerializableState {
    // The serializable subset of the accounts store (stats only; the accounts
    // DB itself lives in stable structures).
    accounts_store: AccountsStoreSerializableState,
    // Assets served by the canister.
    assets: Assets,
    // TVL state, persisted through upgrades.
    tvl_state: TvlState,
}

impl From<State> for SerializableState {
    /// Extracts the upgrade-persisted subset of the canister `State`,
    /// consuming it.
    fn from(state: State) -> Self {
        // Exhaustive destructuring (no `..`): adding a field to `State`
        // forces an explicit decision here about upgrade persistence.
        let State {
            // Persisted through upgrades:
            accounts_store,
            assets,
            tvl_state,
            // Rebuilt after upgrade rather than persisted:
            asset_hashes: _,
            performance: _,
        } = state;

        Self {
            accounts_store: AccountsStoreSerializableState::from(accounts_store),
            assets,
            tvl_state,
        }
    }
}

impl From<(SerializableState, AccountsDbAsProxy)> for State {
    /// Rebuilds the full canister `State` from its serialized subset plus the
    /// accounts DB loaded from stable memory.
    fn from((serializable, accounts_db): (SerializableState, AccountsDbAsProxy)) -> Self {
        let SerializableState {
            accounts_store,
            assets,
            tvl_state,
        } = serializable;

        Self {
            accounts_store: AccountsStore::from((accounts_store, accounts_db)),
            // Derived from `assets`, so recomputed here instead of persisted.
            asset_hashes: AssetHashes::from(&assets),
            // Performance counters start from scratch after an upgrade.
            performance: PerformanceCounts::default(),
            assets,
            tvl_state,
        }
    }
}

// On-wire shape of `SerializableState`: a Candid tuple of the pre-encoded
// bytes of (accounts store, assets, TVL state), in that order.
type SerializableStateCandidType = (Vec<u8>, Vec<u8>, Vec<u8>);

impl StableState for SerializableState {
    /// Encodes the state as a Candid triple of pre-encoded byte vectors:
    /// (accounts store, assets, TVL state).
    fn encode(&self) -> Vec<u8> {
        let triple: SerializableStateCandidType = (
            self.accounts_store.encode(),
            self.assets.encode(),
            self.tvl_state.encode(),
        );
        Encode!(&triple).unwrap()
    }

    /// Decodes the Candid triple, then decodes each component in turn.
    ///
    /// # Errors
    /// Returns the stringified Candid error if the outer tuple cannot be
    /// decoded, or propagates the error from any component's `decode`.
    fn decode(bytes: Vec<u8>) -> Result<Self, String> {
        let (accounts_store_bytes, assets_bytes, tvl_state_bytes) =
            Decode!(&bytes, SerializableStateCandidType).map_err(|e| e.to_string())?;

        Ok(Self {
            accounts_store: AccountsStoreSerializableState::decode(accounts_store_bytes)?,
            assets: Assets::decode(assets_bytes)?,
            tvl_state: TvlState::decode(tvl_state_bytes)?,
        })
    }
}

#[allow(clippy::new_without_default)]
impl State {
/// Creates new state. Should be called in `init`.
Expand All @@ -136,50 +214,31 @@ impl State {
}

/// Recovers the state from stable memory. Should be called in `post_upgrade`.
///
/// # Panics
/// Panics if the state cannot be decoded.
#[must_use]
pub fn new_restored() -> Self {
println!("START state::new_restored: ())");
let accounts_partition = with_partitions(|p| p.get(PartitionType::Accounts.memory_id()));
let mut state = Self::recover_heap_from_managed_memory();
let accounts_db =
AccountsDb::UnboundedStableBTreeMap(AccountsDbAsUnboundedStableBTreeMap::load(accounts_partition));
// Replace the default accountsdb created by `serde` with the one from stable memory.
let _deserialized_accounts_db = state.accounts_store.replace_accounts_db(accounts_db);
let candid_bytes = with_partitions(Partitions::read_bytes_from_managed_memory);

let decoded = SerializableState::decode(candid_bytes).unwrap_or_else(|e| {
panic!("Decoding stable memory failed. Error: {e:?}");
});
let accounts_partition = with_partitions(|partitions| partitions.get(PartitionType::Accounts.memory_id()));
let accounts_db = AccountsDbAsProxy::from(AccountsDb::UnboundedStableBTreeMap(
AccountsDbAsUnboundedStableBTreeMap::load(accounts_partition),
));
let state = State::from((decoded, accounts_db));
println!("END state::new_restored: ()");
state
}

/// Saves the state to stable memory. Should be called in `pre_upgrade`.
pub fn save(&self) {
self.save_heap_to_managed_memory();
}
}

impl StableState for State {
fn encode(&self) -> Vec<u8> {
Candid((
self.accounts_store.encode(),
self.assets.encode(),
self.tvl_state.encode(),
))
.into_bytes()
.unwrap()
}

fn decode(bytes: Vec<u8>) -> Result<Self, String> {
let (account_store_bytes, assets_bytes, tvl_state_bytes) = Candid::from_bytes(bytes).map(|c| c.0)?;

let assets = Assets::decode(assets_bytes)?;
let asset_hashes = AssetHashes::from(&assets);
let performance = PerformanceCounts::default();
let tvl_state = TvlState::decode(tvl_state_bytes)?;

Ok(State {
accounts_store: AccountsStore::decode(account_store_bytes)?,
assets,
asset_hashes,
performance,
tvl_state,
})
// Saves the state. Should be called in `pre_upgrade`.
fn save(self) {
let serializable = SerializableState::from(self);
println!("START state::save_heap: ()");
let bytes = serializable.encode();
with_partitions(|partitions| partitions.write_bytes_to_managed_memory(&bytes));
}
}
2 changes: 1 addition & 1 deletion rs/backend/src/state/partitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl Partitions {
}

/// Writes, growing the memory if necessary.
pub fn growing_write(&self, memory_id: MemoryId, offset: u64, bytes: &[u8]) {
fn growing_write(&self, memory_id: MemoryId, offset: u64, bytes: &[u8]) {
let memory = self.get(memory_id);
let min_pages: u64 = u64::try_from(bytes.len())
.unwrap_or_else(|err| unreachable!("Buffer for growing_write is longer than 2**64 bytes?? Err: {err}"))
Expand Down
Loading

0 comments on commit 3b7de2f

Please sign in to comment.