From 60ba89e483057e28c3319d87078c147af24cb89a Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 6 Mar 2023 12:47:27 +0900 Subject: [PATCH] Fix deserialization order (#87) * Fix deserialization order Fixes #86 * Add test --- .../src/char_scorer/boundary_tag_scorer.rs | 3 +- vaporetto/src/predictor.rs | 46 +++++++++++++++++++ .../src/type_scorer/boundary_tag_scorer.rs | 3 +- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/vaporetto/src/char_scorer/boundary_tag_scorer.rs b/vaporetto/src/char_scorer/boundary_tag_scorer.rs index 1fa06bc..362eb92 100644 --- a/vaporetto/src/char_scorer/boundary_tag_scorer.rs +++ b/vaporetto/src/char_scorer/boundary_tag_scorer.rs @@ -39,6 +39,7 @@ impl<'de> BorrowDecode<'de> for CharScorerBoundaryTag { let (pma, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(pma_data) }; #[cfg(feature = "charwise-pma")] let (pma, _) = unsafe { CharwiseDoubleArrayAhoCorasick::deserialize_unchecked(pma_data) }; + let weights = Decode::decode(decoder)?; let tag_weight: Vec>> = Decode::decode(decoder)?; let tag_weight = tag_weight .into_iter() @@ -46,7 +47,7 @@ impl<'de> BorrowDecode<'de> for CharScorerBoundaryTag { .collect(); Ok(Self { pma, - weights: Decode::decode(decoder)?, + weights, tag_weight, }) } diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index c57bd1c..1f447cf 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -884,6 +884,52 @@ mod tests { ); } + #[cfg(feature = "tag-prediction")] + #[test] + fn test_serialization_tags() { + let model = create_test_model(); + let predictor = Predictor::new(model, true).unwrap(); + let data = predictor.serialize_to_vec().unwrap(); + let (predictor, _) = unsafe { Predictor::deserialize_from_slice_unchecked(&data).unwrap() }; + let mut sentence = Sentence::from_raw("この人は地球人だ").unwrap(); + predictor.predict(&mut sentence); + sentence.fill_tags(); + assert_eq!(&[-22, 54, 58, 43, -54, 68, 48], sentence.boundary_scores(),); + assert_eq!( + &[ + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary + ], + sentence.boundaries(), + ); + assert_eq!( + &[ + None, + None, + None, + None, + Some(Cow::Borrowed("名詞")), + Some(Cow::Borrowed("ヒト")), + None, + None, + None, + None, + Some(Cow::Borrowed("名詞")), + Some(Cow::Borrowed("チキュー")), + Some(Cow::Borrowed("接尾辞")), + Some(Cow::Borrowed("ジン")), + None, + None, + ], + sentence.tags() + ); + } + #[cfg(feature = "tag-prediction")] #[test] #[should_panic] diff --git a/vaporetto/src/type_scorer/boundary_tag_scorer.rs b/vaporetto/src/type_scorer/boundary_tag_scorer.rs index 8aa4f0d..4509c27 100644 --- a/vaporetto/src/type_scorer/boundary_tag_scorer.rs +++ b/vaporetto/src/type_scorer/boundary_tag_scorer.rs @@ -28,6 +28,7 @@ impl<'de> BorrowDecode<'de> for TypeScorerBoundaryTag { fn borrow_decode>(decoder: &mut D) -> Result { let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?; let (pma, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(pma_data) }; + let weights = Decode::decode(decoder)?; let tag_weight: Vec>> = Decode::decode(decoder)?; let tag_weight = tag_weight .into_iter() @@ -35,7 +36,7 @@ impl<'de> BorrowDecode<'de> for TypeScorerBoundaryTag { .collect(); Ok(Self { pma, - weights: Decode::decode(decoder)?, + weights, tag_weight, }) }