diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml
index 816a219..08ebba4 100644
--- a/vaporetto/Cargo.toml
+++ b/vaporetto/Cargo.toml
@@ -13,7 +13,7 @@ keywords = ["japanese", "analyzer", "tokenizer", "morphological"]
 categories = ["text-processing", "no-std"]
 
 [dependencies]
-bincode = { version = "2.0.0-rc.2", default-features = false, features = ["alloc", "derive"] } # MIT
+bincode = { version = "2.0.0-rc.3", default-features = false, features = ["alloc", "derive"] } # MIT
 daachorse = "1.0.0" # MIT or Apache-2.0
 hashbrown = "0.13.2" # MIT or Apache-2.0
 
diff --git a/vaporetto/src/char_scorer/boundary_tag_scorer.rs b/vaporetto/src/char_scorer/boundary_tag_scorer.rs
index 362eb92..995caea 100644
--- a/vaporetto/src/char_scorer/boundary_tag_scorer.rs
+++ b/vaporetto/src/char_scorer/boundary_tag_scorer.rs
@@ -11,7 +11,6 @@ use bincode::{
 use daachorse::charwise::CharwiseDoubleArrayAhoCorasick;
 #[cfg(not(feature = "charwise-pma"))]
 use daachorse::DoubleArrayAhoCorasick;
-use hashbrown::HashMap;
 
 use crate::char_scorer::CharWeightMerger;
 use crate::dict_model::DictModel;
@@ -19,7 +18,7 @@ use crate::errors::{Result, VaporettoError};
 use crate::ngram_model::{NgramModel, TagNgramModel};
 use crate::predictor::{PositionalWeight, PositionalWeightWithTag, WeightVector};
 use crate::sentence::Sentence;
-use crate::utils::SplitMix64Builder;
+use crate::utils::{SerializableHashMap, SplitMix64Builder};
 
 pub struct CharScorerBoundaryTag {
     #[cfg(not(feature = "charwise-pma"))]
@@ -27,7 +26,7 @@ pub struct CharScorerBoundaryTag {
     #[cfg(feature = "charwise-pma")]
     pma: CharwiseDoubleArrayAhoCorasick<u32>,
     weights: Vec<Option<PositionalWeight<WeightVector>>>,
-    tag_weight: Vec<Vec<HashMap<u32, WeightVector, SplitMix64Builder>>>,
+    tag_weight: Vec<Vec<SerializableHashMap<u32, WeightVector, SplitMix64Builder>>>,
 }
 
 impl<'de> BorrowDecode<'de> for CharScorerBoundaryTag {
@@ -40,11 +39,7 @@
         #[cfg(feature = "charwise-pma")]
         let (pma, _) = unsafe { CharwiseDoubleArrayAhoCorasick::deserialize_unchecked(pma_data) };
         let weights = Decode::decode(decoder)?;
-        let tag_weight: Vec<Vec<Vec<(u32, WeightVector)>>> = Decode::decode(decoder)?;
-        let tag_weight = tag_weight
-            .into_iter()
-            .map(|x| x.into_iter().map(|x| x.into_iter().collect()).collect())
-            .collect();
+        let tag_weight = Decode::decode(decoder)?;
         Ok(Self {
             pma,
             weights,
@@ -58,12 +53,7 @@ impl Encode for CharScorerBoundaryTag {
         let pma_data = self.pma.serialize();
         Encode::encode(&pma_data, encoder)?;
         Encode::encode(&self.weights, encoder)?;
-        let tag_weight: Vec<Vec<Vec<(&u32, &WeightVector)>>> = self
-            .tag_weight
-            .iter()
-            .map(|x| x.iter().map(|x| x.iter().collect()).collect())
-            .collect();
-        Encode::encode(&tag_weight, encoder)?;
+        Encode::encode(&self.tag_weight, encoder)?;
         Ok(())
     }
 }
@@ -90,11 +80,10 @@ impl CharScorerBoundaryTag {
             let weight = PositionalWeightWithTag::with_boundary(-word_len, d.weights);
             merger.add(d.word, weight);
         }
-        let mut tag_weight =
-            vec![
-                vec![HashMap::with_hasher(SplitMix64Builder); usize::from(window_size) + 1];
-                tag_ngram_model.len()
-            ];
+        let mut tag_weight = vec![
+            vec![SerializableHashMap::default(); usize::from(window_size) + 1];
+            tag_ngram_model.len()
+        ];
         for (i, tag_model) in tag_ngram_model.into_iter().enumerate() {
             for d in tag_model.0 {
                 for w in d.weights {
diff --git a/vaporetto/src/ngram_model.rs b/vaporetto/src/ngram_model.rs
index 6a00e34..1d141ec 100644
--- a/vaporetto/src/ngram_model.rs
+++ b/vaporetto/src/ngram_model.rs
@@ -9,7 +9,7 @@ pub struct NgramData {
 }
 
 #[derive(Default, Debug, Decode, Encode)]
-pub struct NgramModel(pub Vec>);
+pub struct NgramModel(pub Vec>);
 
 #[derive(Clone, Debug, Decode, Encode)]
 pub struct TagWeight {
@@ -24,4 +24,4 @@ pub struct TagNgramData {
 }
 
 #[derive(Default, Debug, Decode, Encode)]
-pub struct TagNgramModel(pub Vec>);
+pub struct TagNgramModel(pub Vec>);
diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs
index 1f447cf..bb2fa66 100644
--- a/vaporetto/src/predictor.rs
+++ b/vaporetto/src/predictor.rs
@@ -44,6 +44,12 @@ pub enum WeightVector {
     Fixed(I32Simd),
 }
 
+impl Default for WeightVector {
+    fn default() -> Self {
+        Self::Variable(vec![])
+    }
+}
+
 impl Decode for WeightVector {
     fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, DecodeError> {
         let weight: Vec<i32> = Decode::decode(decoder)?;
diff --git a/vaporetto/src/type_scorer/boundary_tag_scorer.rs b/vaporetto/src/type_scorer/boundary_tag_scorer.rs
index 4509c27..d0c3eb1 100644
--- a/vaporetto/src/type_scorer/boundary_tag_scorer.rs
+++ b/vaporetto/src/type_scorer/boundary_tag_scorer.rs
@@ -7,19 +7,18 @@ use bincode::{
     BorrowDecode, Decode, Encode,
 };
 use daachorse::DoubleArrayAhoCorasick;
-use hashbrown::HashMap;
 
 use crate::errors::{Result, VaporettoError};
 use crate::ngram_model::{NgramModel, TagNgramModel};
 use crate::predictor::{PositionalWeight, PositionalWeightWithTag, WeightVector};
 use crate::sentence::Sentence;
 use crate::type_scorer::TypeWeightMerger;
-use crate::utils::SplitMix64Builder;
+use crate::utils::{SerializableHashMap, SplitMix64Builder};
 
 pub struct TypeScorerBoundaryTag {
     pma: DoubleArrayAhoCorasick<u32>,
     weights: Vec<Option<PositionalWeight<WeightVector>>>,
-    tag_weight: Vec<Vec<HashMap<u32, WeightVector, SplitMix64Builder>>>,
+    tag_weight: Vec<Vec<SerializableHashMap<u32, WeightVector, SplitMix64Builder>>>,
 }
 
 impl<'de> BorrowDecode<'de> for TypeScorerBoundaryTag {
@@ -29,11 +28,7 @@
         let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?;
         let (pma, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(pma_data) };
         let weights = Decode::decode(decoder)?;
-        let tag_weight: Vec<Vec<Vec<(u32, WeightVector)>>> = Decode::decode(decoder)?;
-        let tag_weight = tag_weight
-            .into_iter()
-            .map(|x| x.into_iter().map(|x| x.into_iter().collect()).collect())
-            .collect();
+        let tag_weight = Decode::decode(decoder)?;
         Ok(Self {
             pma,
             weights,
@@ -47,12 +42,7 @@ impl Encode for TypeScorerBoundaryTag {
         let pma_data = self.pma.serialize();
         Encode::encode(&pma_data, encoder)?;
         Encode::encode(&self.weights, encoder)?;
-        let tag_weight: Vec<Vec<Vec<(&u32, &WeightVector)>>> = self
-            .tag_weight
-            .iter()
-            .map(|x| x.iter().map(|x| x.iter().collect()).collect())
-            .collect();
-        Encode::encode(&tag_weight, encoder)?;
+        Encode::encode(&self.tag_weight, encoder)?;
         Ok(())
     }
 }
@@ -68,11 +58,10 @@ impl TypeScorerBoundaryTag {
             let weight = PositionalWeightWithTag::with_boundary(-i16::from(window_size), d.weights);
             merger.add(d.ngram, weight);
         }
-        let mut tag_weight =
-            vec![
-                vec![HashMap::with_hasher(SplitMix64Builder); usize::from(window_size) + 1];
-                tag_ngram_model.len()
-            ];
+        let mut tag_weight = vec![
+            vec![SerializableHashMap::default(); usize::from(window_size) + 1];
+            tag_ngram_model.len()
+        ];
         for (i, tag_model) in tag_ngram_model.into_iter().enumerate() {
             for d in tag_model.0 {
                 for w in d.weights {
diff --git a/vaporetto/src/utils.rs b/vaporetto/src/utils.rs
index 2d55b5f..51e3f11 100644
--- a/vaporetto/src/utils.rs
+++ b/vaporetto/src/utils.rs
@@ -12,7 +12,7 @@ use bincode::{
     error::{DecodeError, EncodeError},
     Decode, Encode,
 };
-use hashbrown::HashMap;
+use hashbrown::{hash_map::DefaultHashBuilder, HashMap};
 
 #[cfg(feature = "fix-weight-length")]
 #[inline(always)]
@@ -35,42 +35,53 @@ impl Writer for VecWriter {
     }
 }
 
-#[derive(Debug)]
-pub struct SerializableHashMap<K, V>(pub HashMap<K, V>);
+#[derive(Clone, Debug, Default)]
+pub struct SerializableHashMap<K, V, S = DefaultHashBuilder>(pub HashMap<K, V, S>);
 
-impl<K, V> Deref for SerializableHashMap<K, V> {
-    type Target = HashMap<K, V>;
+impl<K, V, S> Deref for SerializableHashMap<K, V, S> {
+    type Target = HashMap<K, V, S>;
 
     fn deref(&self) -> &Self::Target {
         &self.0
     }
 }
 
-impl<K, V> DerefMut for SerializableHashMap<K, V> {
+impl<K, V, S> DerefMut for SerializableHashMap<K, V, S> {
     fn deref_mut(&mut self) -> &mut Self::Target {
         &mut self.0
     }
 }
 
-impl<K, V> Decode for SerializableHashMap<K, V>
+impl<K, V, S> Decode for SerializableHashMap<K, V, S>
 where
-    K: Encode + Decode + Eq + Hash,
-    V: Encode + Decode,
+    K: Decode + Eq + Hash,
+    V: Decode,
+    S: BuildHasher + Default,
 {
     fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, DecodeError> {
-        let raw: Vec<(K, V)> = Decode::decode(decoder)?;
-        Ok(Self(raw.into_iter().collect()))
+        let mut result = HashMap::with_hasher(S::default());
+        let size: u64 = Decode::decode(decoder)?;
+        for _ in 0..size {
+            let k = Decode::decode(decoder)?;
+            let v = Decode::decode(decoder)?;
+            result.insert(k, v);
+        }
+        Ok(Self(result))
    }
 }
 
-impl<K, V> Encode for SerializableHashMap<K, V>
+impl<K, V, S> Encode for SerializableHashMap<K, V, S>
 where
-    K: Encode + Decode,
-    V: Encode + Decode,
+    K: Encode,
+    V: Encode,
 {
     fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
-        let raw: Vec<(&K, &V)> = self.0.iter().collect();
-        Encode::encode(&raw, encoder)?;
+        let size = u64::try_from(self.0.len()).unwrap();
+        Encode::encode(&size, encoder)?;
+        for (k, v) in &self.0 {
+            Encode::encode(k, encoder)?;
+            Encode::encode(v, encoder)?;
+        }
         Ok(())
     }
 }
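Side note on the new SerializableHashMap codec (a sketch, not part of the patch): bincode's standard configuration encodes a collection as a u64 length followed by its items, so the explicit u64-plus-key/value loop above should produce the same bytes as the Vec<(K, V)> round-trip it replaces, which would keep previously serialized models decodable. The snippet below checks that assumption against bincode 2.0.0-rc.3 directly; the pairs vector and the u32/i32 types are illustrative stand-ins for a map's keys and values.

// Sketch: compare the old and new on-disk layouts for a map's contents.
use bincode::{config, encode_to_vec};

fn main() {
    // Illustrative stand-in for the (key, value) pairs of a tag-weight map.
    let pairs: Vec<(u32, i32)> = vec![(1, -5), (2, 7)];

    // Old layout: the whole Vec<(K, V)> handed to bincode in one call.
    let old = encode_to_vec(pairs.clone(), config::standard()).unwrap();

    // New layout: an explicit u64 length, then each key and value in turn,
    // mirroring SerializableHashMap::encode in the patch.
    let mut new = encode_to_vec(pairs.len() as u64, config::standard()).unwrap();
    for &(k, v) in &pairs {
        new.extend(encode_to_vec(k, config::standard()).unwrap());
        new.extend(encode_to_vec(v, config::standard()).unwrap());
    }

    // With the standard configuration both layouts are byte-identical.
    assert_eq!(old, new);
}

For a real map the pair order is whatever the hash map's iterator yields, in both the old and the new code, so the guarantee is decodability rather than byte-for-byte stability of the serialized model.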