diff --git a/framework/Cargo.toml b/framework/Cargo.toml
index 9235e2373..83acd2f17 100644
--- a/framework/Cargo.toml
+++ b/framework/Cargo.toml
@@ -17,7 +17,7 @@ util = { path = "../built-in-services/util"}
 hasher = { version = "0.1", features = ['hash-keccak'] }
 cita_trie = "2.0"
 bytes = "0.5"
-derive_more = "0.15"
+derive_more = "0.99"
 rocksdb = "0.12"
 lazy_static = "1.4"
 byteorder = "1.3"
@@ -27,9 +27,12 @@ json = "0.12"
 hex = "0.4"
 serde_json = "1.0"
 log = "0.4"
+rayon = "1.3"

 [dev-dependencies]
 async-trait = "0.1"
 toml = "0.5"
 binding-macro = { path = "../binding-macro" }
+rand = "0.7"
 serde = { version = "1.0", features = ["derive"] }
+muta-codec-derive = "0.2"
diff --git a/framework/src/binding/sdk/mod.rs b/framework/src/binding/sdk/mod.rs
index b385bfdd1..e3892b792 100644
--- a/framework/src/binding/sdk/mod.rs
+++ b/framework/src/binding/sdk/mod.rs
@@ -40,7 +40,10 @@ impl ServiceSDK for DefalutServiceSDK
 {
     // Alloc or recover a `Map` by` var_name`
-    fn alloc_or_recover_map<K: 'static + FixedCodec + PartialEq, V: 'static + FixedCodec>(
+    fn alloc_or_recover_map<
+        K: 'static + Send + FixedCodec + Clone + PartialEq,
+        V: 'static + FixedCodec,
+    >(
         &mut self,
         var_name: &str,
     ) -> Box<dyn StoreMap<K, V>> {
diff --git a/framework/src/binding/store/map.rs b/framework/src/binding/store/map.rs
index 4d4c52dfb..ef8d83563 100644
--- a/framework/src/binding/store/map.rs
+++ b/framework/src/binding/store/map.rs
@@ -4,107 +4,203 @@
 use std::marker::PhantomData;
 use std::rc::Rc;

 use bytes::Bytes;
+use rayon::prelude::*;

 use protocol::fixed_codec::FixedCodec;
 use protocol::traits::{ServiceState, StoreMap};
 use protocol::types::Hash;
 use protocol::{ProtocolError, ProtocolResult};

-use crate::binding::store::{FixedKeys, StoreError};
+use crate::binding::store::{get_bucket_index, Bucket, FixedBuckets, StoreError};

 pub struct DefaultStoreMap<S: ServiceState, K: FixedCodec + PartialEq, V: FixedCodec> {
     state: Rc<RefCell<S>>,
-    var_name: Hash,
-    keys: FixedKeys<K>,
+    var_name: String,
+    keys: RefCell<FixedBuckets<K>>,
+    len_key: Bytes,
+    len: u32,
     phantom: PhantomData<V>,
 }

-impl<S: 'static + ServiceState, K: 'static + FixedCodec + PartialEq, V: 'static + FixedCodec>
-    DefaultStoreMap<S, K, V>
+impl<S, K, V> DefaultStoreMap<S, K, V>
+where
+    S: 'static + ServiceState,
+    K: 'static + Send + FixedCodec + PartialEq,
+    V: 'static + FixedCodec,
 {
     pub fn new(state: Rc<RefCell<S>>, name: &str) -> Self {
-        let var_name = Hash::digest(Bytes::from(name.to_owned() + "map"));
-
-        let opt_bs: Option<Bytes> = state
+        let len_key = Bytes::from(name.to_string() + "_map_len");
+        let len = state
             .borrow()
-            .get(&var_name)
-            .expect("get map should not fail");
-
-        let keys = if let Some(bs) = opt_bs {
-            <_>::decode_fixed(bs).expect("decode keys should not fail")
-        } else {
-            FixedKeys { inner: Vec::new() }
-        };
+            .get(&len_key)
+            .expect("Get len failed")
+            .unwrap_or(0u32);

-        Self {
+        DefaultStoreMap {
             state,
-            var_name,
-            keys,
+            len_key,
+            len,
+            var_name: name.to_string(),
+            keys: RefCell::new(FixedBuckets::new()),
             phantom: PhantomData,
         }
     }

-    fn get_map_key(&self, key: &K) -> ProtocolResult<Hash> {
-        let mut name_bytes = self.var_name.as_bytes().to_vec();
-        name_bytes.extend_from_slice(key.encode_fixed()?.as_ref());
+    fn inner_insert(&mut self, key: K, value: V) -> ProtocolResult<()> {
+        let key_bytes = key.encode_fixed()?;
+        let mk = self.get_map_key(&key_bytes);
+        let bkt_idx = get_bucket_index(&key_bytes);

-        Ok(Hash::digest(Bytes::from(name_bytes)))
+        if !self.inner_contains(bkt_idx, &key)? {
+            self.keys.borrow_mut().insert(bkt_idx, key);
+
+            self.state.borrow_mut().insert(
+                self.get_bucket_name(bkt_idx),
+                self.keys.borrow().get_bucket(bkt_idx).encode_fixed()?,
+            )?;
+            self.len_add_one()?;
+        }
+        self.state.borrow_mut().insert(mk, value)
     }

     fn inner_get(&self, key: &K) -> ProtocolResult<Option<V>> {
-        if self.keys.inner.contains(key) {
-            let mk = self.get_map_key(key)?;
-            self.state.borrow().get(&mk)?.map_or_else(
-                || {
-                    Ok(Some(<_>::decode_fixed(Bytes::new()).map_err(|_| {
-                        ProtocolError::from(StoreError::DecodeError)
-                    })?))
-                },
-                |v| Ok(Some(v)),
-            )
+        let key_bytes = key.encode_fixed()?;
+        let bkt_idx = get_bucket_index(&key_bytes);
+
+        if self.inner_contains(bkt_idx, &key)? {
+            self.state
+                .borrow()
+                .get(&self.get_map_key(&key_bytes))?
+                .map_or_else(
+                    || {
+                        Ok(Some(<_>::decode_fixed(Bytes::new()).map_err(|_| {
+                            ProtocolError::from(StoreError::DecodeError)
+                        })?))
+                    },
+                    |v| Ok(Some(v)),
+                )
         } else {
             Ok(None)
         }
     }

-    // TODO(@zhounan): Atomicity of insert(k, v) and insert self.keys to
-    // ServiceState is not guaranteed for now That must be settled soon after.
-    fn inner_insert(&mut self, key: K, value: V) -> ProtocolResult<()> {
-        let mk = self.get_map_key(&key)?;
+    fn inner_remove(&mut self, key: &K) -> ProtocolResult<Option<V>> {
+        let key_bytes = key.encode_fixed()?;
+        let bkt_idx = get_bucket_index(&key_bytes);

-        if !self.contains(&key) {
-            self.keys.inner.push(key);
+        if self.inner_contains(bkt_idx, &key)? {
+            let value = self.inner_get(key)?.expect("value should be existed");
+            let bkt_idx = get_bucket_index(&key_bytes);
+            let bkt_name = self.get_bucket_name(bkt_idx);
+
+            let _ = self.keys.borrow_mut().remove_item(bkt_idx, key)?;
+            self.state.borrow_mut().insert(
+                bkt_name,
+                self.keys.borrow().get_bucket(bkt_idx).encode_fixed()?,
+            )?;
             self.state
                 .borrow_mut()
-                .insert(self.var_name.clone(), self.keys.encode_fixed()?)?;
+                .insert(self.get_map_key(&key_bytes), Bytes::new())?;
+            self.len_sub_one()?;
+            Ok(Some(value))
+        } else {
+            Ok(None)
         }
-
-        self.state.borrow_mut().insert(mk, value)
     }

-    // TODO(@zhounan): Atomicity of insert(k, v) and insert self.keys to
-    // ServiceState is not guaranteed for now That must be settled soon after.
-    fn inner_remove(&mut self, key: &K) -> ProtocolResult<Option<V>> {
-        if self.contains(key) {
-            let value: V = self.inner_get(key)?.expect("value should be existed");
-            self.keys.inner.remove_item(key);
-            self.state
-                .borrow_mut()
-                .insert(self.var_name.clone(), self.keys.encode_fixed()?)?;
+    #[inline(always)]
+    fn inner_contains(&self, bkt_idx: usize, key: &K) -> ProtocolResult<bool> {
+        if self.keys.borrow().is_bucket_recovered(bkt_idx) {
+            return Ok(self.keys.borrow().contains(bkt_idx, key));
+        }

-            self.state
-                .borrow_mut()
-                .insert(self.get_map_key(key)?, Bytes::new())?;
+        let bkt = if let Some(bytes) = self.state.borrow().get(&self.get_bucket_name(bkt_idx))? {
+            <_>::decode_fixed(bytes)?
+        } else {
+            Bucket::new()
+        };

-            Ok(Some(value))
+        let ret = bkt.contains(key);
+        self.keys.borrow_mut().recover_bucket(bkt_idx, bkt);
+        Ok(ret)
+    }
+
+    fn get_map_key(&self, key_bytes: &Bytes) -> Bytes {
+        let mut name_bytes = self.var_name.as_bytes().to_vec();
+        name_bytes.extend_from_slice(key_bytes);
+
+        if key_bytes.len() > 32 {
+            Hash::digest(Bytes::from(name_bytes)).as_bytes()
         } else {
-            Ok(None)
+            Bytes::from(name_bytes)
+        }
+    }
+
+    fn get_bucket_name(&self, index: usize) -> Bytes {
+        let mut bytes = (self.var_name.clone() + "_bucket_").as_bytes().to_vec();
+        bytes.extend_from_slice(&index.to_le_bytes());
+        Bytes::from(bytes)
+    }
+
+    fn len_add_one(&mut self) -> ProtocolResult<()> {
+        self.len += 1;
+        self.state
+            .borrow_mut()
+            .insert(self.len_key.clone(), self.len.encode_fixed()?)
+    }
+
+    fn len_sub_one(&mut self) -> ProtocolResult<()> {
+        self.len -= 1;
+        self.state
+            .borrow_mut()
+            .insert(self.len_key.clone(), self.len.encode_fixed()?)
+    }
+
+    fn recover_all_buckets(&self) {
+        let idxs = self
+            .keys
+            .borrow()
+            .is_recovered
+            .iter()
+            .enumerate()
+            .filter_map(|(i, &res)| if !res { Some(i) } else { None })
+            .collect::<Vec<usize>>();
+
+        let opt_bytes = idxs
+            .iter()
+            .map(|idx| {
+                let name = self.get_bucket_name(*idx);
+                self.state.borrow().get(&name).unwrap()
+            })
+            .collect::<Vec<Option<Bytes>>>();
+
+        let buckets = opt_bytes
+            .into_par_iter()
+            .map(|bytes| {
+                if let Some(bs) = bytes {
+                    <_>::decode_fixed(bs).expect("Decode bucket failed")
+                } else {
+                    Bucket::new()
+                }
+            })
+            .collect::<Vec<Bucket<K>>>();
+
+        for (idx, bkt) in idxs.into_iter().zip(buckets.into_iter()) {
+            self.keys.borrow_mut().recover_bucket(idx, bkt);
         }
     }
+
+    #[cfg(test)]
+    fn get_buckets(self) -> FixedBuckets<K> {
+        self.keys.into_inner()
+    }
 }

-impl<S: 'static + ServiceState, K: 'static + FixedCodec + PartialEq, V: 'static + FixedCodec>
-    StoreMap<K, V> for DefaultStoreMap<S, K, V>
+impl<S, K, V> StoreMap<K, V> for DefaultStoreMap<S, K, V>
+where
+    S: 'static + ServiceState,
+    K: 'static + Send + FixedCodec + Clone + PartialEq,
+    V: 'static + FixedCodec,
 {
     fn get(&self, key: &K) -> Option<V> {
         self.inner_get(key)
@@ -122,27 +218,29 @@ impl
     fn contains(&self, key: &K) -> bool {
-        self.keys.inner.contains(key)
+        if let Ok(bytes) = key.encode_fixed() {
+            self.inner_contains(get_bucket_index(&bytes), &key)
+                .unwrap_or(false)
+        } else {
+            false
+        }
     }

     fn len(&self) -> u32 {
-        self.keys.inner.len() as u32
+        self.len
     }

     fn is_empty(&self) -> bool {
-        if let 0 = self.len() {
-            true
-        } else {
-            false
-        }
+        self.len == 0
     }

-    fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (&'a K, V)> + 'a> {
-        Box::new(MapIter::<S, K, V>::new(0, self))
+    fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (K, V)> + 'a> {
+        self.recover_all_buckets();
+        Box::new(NewMapIter::<S, K, V>::new(0, self))
     }
 }

-pub struct MapIter<
+pub struct NewMapIter<
     'a,
     S: 'static + ServiceState,
     K: 'static + FixedCodec + PartialEq,
@@ -152,39 +250,104 @@ pub struct MapIter<
     map: &'a DefaultStoreMap<S, K, V>,
 }

-impl<
-    'a,
-    S: 'static + ServiceState,
-    K: 'static + FixedCodec + PartialEq,
-    V: 'static + FixedCodec,
-    > MapIter<'a, S, K, V>
+impl<'a, S, K, V> NewMapIter<'a, S, K, V>
+where
+    S: 'static + ServiceState,
+    K: 'static + FixedCodec + PartialEq,
+    V: 'static + FixedCodec,
 {
     pub fn new(idx: u32, map: &'a DefaultStoreMap<S, K, V>) -> Self {
         Self { idx, map }
     }
 }

-impl<
-    'a,
-    S: 'static + ServiceState,
-    K: 'static + FixedCodec + PartialEq,
-    V: 'static + FixedCodec,
-    > Iterator for MapIter<'a, S, K, V>
+impl<'a, S, K, V> Iterator for NewMapIter<'a, S, K, V>
+where
+    S: 'static + ServiceState,
+    K: 'static + Send + FixedCodec + Clone + PartialEq,
+    V: 'static + FixedCodec,
 {
-    type Item = (&'a K, V);
+    type Item = (K, V);

     fn next(&mut self) -> Option<Self::Item> {
-        if self.idx < self.map.len() {
-            let key = self
-                .map
-                .keys
-                .inner
-                .get(self.idx as usize)
-                .expect("get key should not fail");
-            self.idx += 1;
-            Some((key, self.map.get(key).expect("get value should not fail")))
-        } else {
-            None
+        let idx = self.idx;
+        if idx >= self.map.len {
+            return None;
+        }
+
+        for i in 0..16 {
+            let (left, right) = self.map.keys.borrow().get_abs_index_interval(i);
+            if left <= idx && idx < right {
+                let index = idx - left;
+                let key = self.map.keys.borrow().keys_bucket[i]
+                    .0
+                    .get(index as usize)
+                    .cloned()
+                    .expect("get key should not fail");
+
+                self.idx += 1;
+                return Some((
+                    key.clone(),
+                    self.map.get(&key).expect("get value should not fail"),
+                ));
+            }
         }
+        None
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use cita_trie::MemoryDB;
+    use rand::random;
+
+    use crate::binding::state::{GeneralServiceState, MPTTrie};
+    use crate::binding::store::map::DefaultStoreMap;
+
+    use super::*;
+
+    fn gen_bytes() -> Bytes {
+        Bytes::from((0..16).map(|_| random::<u8>()).collect::<Vec<u8>>())
+    }
+
+    #[test]
+    fn test_map_and_bucket() {
+        let state = Rc::new(RefCell::new(GeneralServiceState::new(MPTTrie::new(
+            Arc::new(MemoryDB::new(false)),
+        ))));
+        let mut map = DefaultStoreMap::<_, Bytes, Bytes>::new(Rc::clone(&state), "test");
+        let key_1 = gen_bytes();
+        let val_1 = gen_bytes();
+        let key_2 = gen_bytes();
+        let val_2 = gen_bytes();
+        let key_idx_1 = get_bucket_index(&key_1.encode_fixed().unwrap());
+        let key_idx_2 = get_bucket_index(&key_2.encode_fixed().unwrap());
+
+        map.insert(key_1, val_1);
+        map.insert(key_2, val_2);
+
+        assert_eq!(map.len(), 2);
+
+        let fbkt = map.get_buckets();
+        assert!(fbkt.is_recovered[key_idx_1]);
+        assert!(fbkt.is_recovered[key_idx_2]);
+        assert_eq!(fbkt.len(), 2);
+
+        let max = key_idx_1.max(key_idx_2);
+        let min = key_idx_1.min(key_idx_2);
+        let res = (0..17)
+            .map(|i| {
+                if i > max {
+                    2u32
+                } else if i > min {
+                    1u32
+                } else {
+                    0u32
+                }
+            })
+            .collect::<Vec<u32>>();
+        assert_eq!(fbkt.bucket_lens, res);
+    }
+}
diff --git a/framework/src/binding/store/mod.rs b/framework/src/binding/store/mod.rs
index 92f02fae3..d669c2ed6 100644
--- a/framework/src/binding/store/mod.rs
+++ b/framework/src/binding/store/mod.rs
@@ -53,6 +53,173 @@ impl FixedCodec for FixedKeys {
     }
 }

+pub struct FixedBuckets<K: FixedCodec + PartialEq> {
+    pub keys_bucket: Vec<Bucket<K>>,
+    pub bucket_lens: Vec<u32>,
+    pub is_recovered: Vec<bool>,
+}
+
+impl<K: FixedCodec + PartialEq> FixedBuckets<K> {
+    fn new() -> Self {
+        let mut keys_bucket = Vec::new();
+        let mut bucket_lens = vec![0];
+        let mut is_recovered = Vec::new();
+
+        for _i in 0..16 {
+            keys_bucket.push(Bucket::new());
+            bucket_lens.push(0u32);
+            is_recovered.push(false);
+        }
+
+        FixedBuckets {
+            keys_bucket,
+            bucket_lens,
+            is_recovered,
+        }
+    }
+
+    fn recover_bucket(&mut self, index: usize, bucket: Bucket<K>) {
+        self.keys_bucket[index] = bucket;
+        self.is_recovered[index] = true;
+        self.update_index_interval(index);
+    }
+
+    fn insert(&mut self, index: usize, key: K) {
+        let bkt = self.keys_bucket.get_mut(index).unwrap();
+        bkt.push(key);
+        self.update_index_interval(index);
+    }
+
+    fn contains(&self, index: usize, key: &K) -> bool {
+        self.keys_bucket[index].contains(key)
+    }
+
+    fn remove_item(&mut self, index: usize, key: &K) -> ProtocolResult<K> {
+        let bkt = self.keys_bucket.get_mut(index).unwrap();
+        if bkt.contains(key) {
+            let val = bkt.remove_item(key)?;
+            self.update_index_interval(index);
+            Ok(val)
+        } else {
+            Err(StoreError::GetNone.into())
+        }
+    }
+
+    fn get_bucket(&self, index: usize) -> &Bucket<K> {
+        self.keys_bucket
+            .get(index)
+            .expect("index must less than 16")
+    }
+
+    /// The function will panic when index is greater than or equal 16.
+    fn get_abs_index_interval(&self, index: usize) -> (u32, u32) {
+        (self.bucket_lens[index], self.bucket_lens[index + 1])
+    }
+
+    fn is_bucket_recovered(&self, index: usize) -> bool {
+        self.is_recovered[index]
+    }
+
+    fn update_index_interval(&mut self, index: usize) {
+        let start = index + 1;
+        let mut acc = self.bucket_lens[index];
+
+        for i in start..17 {
+            acc += self.keys_bucket[i - 1].len() as u32;
+            self.bucket_lens[i] = acc;
+        }
+    }
+
+    #[cfg(test)]
+    fn len(&self) -> u32 {
+        self.bucket_lens[16]
+    }
+
+    #[cfg(test)]
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+}
+
+pub struct Bucket<K: FixedCodec + PartialEq>(Vec<K>);
+
+impl<K: FixedCodec + PartialEq> Bucket<K> {
+    fn new() -> Self {
+        Bucket(Vec::new())
+    }
+
+    fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    fn contains(&self, x: &K) -> bool {
+        self.0.contains(x)
+    }
+
+    fn push(&mut self, value: K) {
+        self.0.push(value);
+    }
+
+    fn remove_item(&mut self, key: &K) -> ProtocolResult<K> {
+        let mut idx = self.len();
+        for (i, item) in self.0.iter().enumerate() {
+            if item == key {
+                idx = i;
+                break;
+            }
+        }
+
+        if idx < self.len() {
+            Ok(self.0.remove(idx))
+        } else {
+            Err(StoreError::GetNone.into())
+        }
+    }
+}
+
+impl<K: FixedCodec + PartialEq> rlp::Encodable for Bucket<K> {
+    fn rlp_append(&self, s: &mut rlp::RlpStream) {
+        let inner: Vec<Vec<u8>> = self
+            .0
+            .iter()
+            .map(|k| k.encode_fixed().expect("encode should not fail").to_vec())
+            .collect();
+
+        s.begin_list(1).append_list::<Vec<u8>, _>(&inner);
+    }
+}
+
+impl<K: FixedCodec + PartialEq> rlp::Decodable for Bucket<K> {
+    fn decode(r: &rlp::Rlp) -> Result<Self, rlp::DecoderError> {
+        let inner_u8: Vec<Vec<u8>> = rlp::decode_list(r.at(0)?.as_raw());
+
+        let inner_k: Result<Vec<K>, _> = inner_u8
+            .into_iter()
+            .map(|v| <_>::decode_fixed(Bytes::from(v)))
+            .collect();
+
+        let inner = inner_k.map_err(|_| rlp::DecoderError::Custom("decode K from bytes fail"))?;
+
+        Ok(Bucket(inner))
+    }
+}
+
+impl<K: FixedCodec + PartialEq> FixedCodec for Bucket<K> {
+    fn encode_fixed(&self) -> ProtocolResult<Bytes> {
+        Ok(Bytes::from(rlp::encode(self)))
+    }
+
+    fn decode_fixed(bytes: Bytes) -> ProtocolResult<Self> {
+        Ok(rlp::decode(bytes.as_ref()).map_err(FixedCodecError::from)?)
+    }
+}
+
+#[inline(always)]
+fn get_bucket_index(bytes: &Bytes) -> usize {
+    let len = bytes.len() - 1;
+    (bytes[len] >> 4) as usize
+}
+
 #[derive(Debug, Display, From)]
 pub enum StoreError {
     #[display(fmt = "the key not existed")]
@@ -75,3 +242,80 @@ impl From for ProtocolError {
         ProtocolError::new(ProtocolErrorKind::Binding, Box::new(err))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_insert() {
+        let mut buckets = FixedBuckets::new();
+        assert!(buckets.is_empty());
+
+        for i in 0..=255u8 {
+            let key = Bytes::from(vec![i]);
+            buckets.insert(get_bucket_index(&key), key);
+        }
+
+        println!("{:?}", buckets.bucket_lens);
+
+        let intervals = (0u32..=16).map(|i| i * 16).collect::<Vec<u32>>();
+        assert!(intervals == buckets.bucket_lens);
+        assert!(buckets.len() == 256);
+
+        for i in 0..16 {
+            assert!(buckets.get_bucket(i).len() == 16);
+        }
+
+        let mut buckets = FixedBuckets::new();
+        for i in 0..8 {
+            let key = Bytes::from(vec![i]);
+            buckets.insert(get_bucket_index(&key), key);
+        }
+
+        assert!(buckets.get_bucket(0).len() == 8);
+        assert!(buckets.len() == 8);
+        for i in 1..16 {
+            assert!(buckets.get_bucket(i).len() == 0);
+        }
+    }
+
+    #[test]
+    fn test_remove() {
+        let mut buckets = FixedBuckets::new();
+
+        for i in 0..=255u8 {
+            let key = Bytes::from(vec![i]);
+            buckets.insert(get_bucket_index(&key), key);
+        }
+
+        let key = Bytes::from(vec![0]);
+        let _ = buckets
+            .remove_item(get_bucket_index(&key.encode_fixed().unwrap()), &key)
+            .unwrap();
+        let intervals = (0u32..=16)
+            .map(|i| if i == 0 { 0 } else { i * 16 - 1 })
+            .collect::<Vec<u32>>();
+        assert!(buckets.len() == 255);
+        assert!(intervals == buckets.bucket_lens);
+    }
+
+    #[test]
+    fn test_contains() {
+        let mut buckets = FixedBuckets::new();
+
+        for i in 0..3u8 {
+            let key = Bytes::from(vec![i]);
+            buckets.insert(get_bucket_index(&key), key);
+        }
+
+        let key = Bytes::from(vec![0]);
+        assert!(buckets.contains(get_bucket_index(&key.encode_fixed().unwrap()), &key));
+
+        let key = Bytes::from(vec![5]);
+        assert!(!buckets.contains(get_bucket_index(&key.encode_fixed().unwrap()), &key));
+
+        let key = Bytes::from(vec![20]);
+        assert!(!buckets.contains(get_bucket_index(&key.encode_fixed().unwrap()), &key));
+    }
+}
diff --git a/framework/src/binding/tests/sdk.rs b/framework/src/binding/tests/sdk.rs
index fbb58aa14..9282c399a 100644
--- a/framework/src/binding/tests/sdk.rs
+++ b/framework/src/binding/tests/sdk.rs
@@ -57,7 +57,7 @@ fn test_service_sdk() {
     let mut it = sdk_map.iter();
     assert_eq!(
         it.next().unwrap(),
-        (&Hash::digest(Bytes::from("key_1")), Bytes::from("val_1"))
+        (Hash::digest(Bytes::from("key_1")), Bytes::from("val_1"))
     );
     assert_eq!(it.next().is_none(), true);
diff --git a/framework/src/binding/tests/store.rs b/framework/src/binding/tests/store.rs
index 6d55161d1..1fff344b7 100644
--- a/framework/src/binding/tests/store.rs
+++ b/framework/src/binding/tests/store.rs
@@ -92,11 +92,11 @@ fn test_default_store_map() {
         let mut it = sm.iter();
         assert_eq!(
             it.next().unwrap(),
-            (&Hash::digest(Bytes::from("key_1")), Bytes::from("val_1"))
+            (Hash::digest(Bytes::from("key_2")), Bytes::from("val_2"))
         );
         assert_eq!(
             it.next().unwrap(),
-            (&Hash::digest(Bytes::from("key_2")), Bytes::from("val_2"))
+            (Hash::digest(Bytes::from("key_1")), Bytes::from("val_1"))
         );
         assert_eq!(it.next().is_none(), true);
     }
@@ -113,7 +113,13 @@ fn test_default_store_map() {
     sm.remove(&Hash::digest(Bytes::from("key_1"))).unwrap();
     assert_eq!(sm.contains(&Hash::digest(Bytes::from("key_1"))), false);
-    assert_eq!(sm.len(), 1u32)
+    assert_eq!(sm.len(), 1u32);
+
+    let sm = DefaultStoreMap::<_, Hash, Bytes>::new(Rc::clone(&rs), "test");
+    assert_eq!(
+        sm.get(&Hash::digest(Bytes::from("key_2"))).unwrap(),
+        Bytes::from("val_2")
+    );
 }

 #[test]
diff --git a/framework/src/executor/mod.rs b/framework/src/executor/mod.rs
index 3b27298e1..5c389eeb7 100644
--- a/framework/src/executor/mod.rs
+++ b/framework/src/executor/mod.rs
@@ -11,7 +11,7 @@
 use std::rc::Rc;
 use std::sync::Arc;

 use cita_trie::DB as TrieDB;
-use derive_more::{Display, From};
+use derive_more::Display;

 use bytes::BytesMut;
 use protocol::traits::{
@@ -409,7 +409,7 @@ impl
diff --git a/protocol/src/traits/binding.rs b/protocol/src/traits/binding.rs
--- a/protocol/src/traits/binding.rs
+++ b/protocol/src/traits/binding.rs
-    fn alloc_or_recover_map<Key: 'static + FixedCodec + PartialEq, Val: 'static + FixedCodec>(
+    fn alloc_or_recover_map<
+        Key: 'static + Send + FixedCodec + Clone + PartialEq,
+        Val: 'static + FixedCodec,
+    >(
         &mut self,
         var_name: &str,
     ) -> Box<dyn StoreMap<Key, Val>>;
@@ -210,7 +213,7 @@ pub trait StoreMap {
     fn is_empty(&self) -> bool;

-    fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (&'a K, V)> + 'a>;
+    fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (K, V)> + 'a>;
 }

 pub trait StoreArray {
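Note on the new layout (illustration only, not part of the patch): keys are spread over 16 buckets chosen by the high nibble of the last byte of the encoded key, and `bucket_lens` keeps 17 cumulative counts so the iterator can map a global index to a (bucket, offset) pair through `get_abs_index_interval`. The sketch below mirrors that arithmetic under two assumptions: plain byte slices stand in for the `bytes::Bytes` produced by `FixedCodec`, and the names `locate` and `lens` are hypothetical, not taken from the patch.

// Illustrative sketch, not the patch's code. `&[u8]` stands in for the
// encoded `bytes::Bytes` key; `locate`/`lens` are made-up names.
fn get_bucket_index(encoded_key: &[u8]) -> usize {
    // Same rule as the patch: high nibble of the last byte => one of 16 buckets.
    (encoded_key[encoded_key.len() - 1] >> 4) as usize
}

// `lens` plays the role of `FixedBuckets::bucket_lens`: 17 cumulative counts,
// so bucket `i` owns the global index range [lens[i], lens[i + 1]).
fn locate(lens: &[u32; 17], idx: u32) -> Option<(usize, u32)> {
    (0..16).find_map(|i| {
        let (left, right) = (lens[i], lens[i + 1]);
        if left <= idx && idx < right {
            Some((i, idx - left))
        } else {
            None
        }
    })
}

fn main() {
    // Insert every possible last byte: 16 buckets of 16 keys each, which is
    // the shape `test_insert` in store/mod.rs asserts for `bucket_lens`.
    let mut lens = [0u32; 17];
    for b in 0u8..=255 {
        let bkt = get_bucket_index(&[b]);
        for slot in lens.iter_mut().skip(bkt + 1) {
            *slot += 1;
        }
    }
    assert!(lens.iter().enumerate().all(|(i, &v)| v == i as u32 * 16));

    // Global index 40 falls into bucket 2 at offset 8, i.e. 40 - 2 * 16.
    assert_eq!(locate(&lens, 40), Some((2, 8)));
    println!("bucket_lens = {:?}", lens);
}

With this scheme a point lookup only has to recover the single bucket that could hold the key (done lazily in `inner_contains`), while `iter` first recovers all 16 buckets, decoding them in parallel with rayon in `recover_all_buckets`, and then walks them in bucket order.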