Skip to content

Commit

Permalink
Fix deserialization order (#87)
Browse files Browse the repository at this point in the history
* Fix deserialization order

Fixes #86

* Add test
  • Loading branch information
vbkaisetsu authored Mar 6, 2023
1 parent d1c428f commit 60ba89e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
3 changes: 2 additions & 1 deletion vaporetto/src/char_scorer/boundary_tag_scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ impl<'de> BorrowDecode<'de> for CharScorerBoundaryTag {
let (pma, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(pma_data) };
#[cfg(feature = "charwise-pma")]
let (pma, _) = unsafe { CharwiseDoubleArrayAhoCorasick::deserialize_unchecked(pma_data) };
let weights = Decode::decode(decoder)?;
let tag_weight: Vec<Vec<Vec<(u32, WeightVector)>>> = Decode::decode(decoder)?;
let tag_weight = tag_weight
.into_iter()
.map(|x| x.into_iter().map(|x| x.into_iter().collect()).collect())
.collect();
Ok(Self {
pma,
weights: Decode::decode(decoder)?,
weights,
tag_weight,
})
}
Expand Down
46 changes: 46 additions & 0 deletions vaporetto/src/predictor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,52 @@ mod tests {
);
}

#[cfg(feature = "tag-prediction")]
#[test]
fn test_serialization_tags() {
let model = create_test_model();
let predictor = Predictor::new(model, true).unwrap();
let data = predictor.serialize_to_vec().unwrap();
let (predictor, _) = unsafe { Predictor::deserialize_from_slice_unchecked(&data).unwrap() };
let mut sentence = Sentence::from_raw("この人は地球人だ").unwrap();
predictor.predict(&mut sentence);
sentence.fill_tags();
assert_eq!(&[-22, 54, 58, 43, -54, 68, 48], sentence.boundary_scores(),);
assert_eq!(
&[
NotWordBoundary,
WordBoundary,
WordBoundary,
WordBoundary,
NotWordBoundary,
WordBoundary,
WordBoundary
],
sentence.boundaries(),
);
assert_eq!(
&[
None,
None,
None,
None,
Some(Cow::Borrowed("名詞")),
Some(Cow::Borrowed("ヒト")),
None,
None,
None,
None,
Some(Cow::Borrowed("名詞")),
Some(Cow::Borrowed("チキュー")),
Some(Cow::Borrowed("接尾辞")),
Some(Cow::Borrowed("ジン")),
None,
None,
],
sentence.tags()
);
}

#[cfg(feature = "tag-prediction")]
#[test]
#[should_panic]
Expand Down
3 changes: 2 additions & 1 deletion vaporetto/src/type_scorer/boundary_tag_scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ impl<'de> BorrowDecode<'de> for TypeScorerBoundaryTag {
fn borrow_decode<D: BorrowDecoder<'de>>(decoder: &mut D) -> Result<Self, DecodeError> {
let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?;
let (pma, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(pma_data) };
let weights = Decode::decode(decoder)?;
let tag_weight: Vec<Vec<Vec<(u32, WeightVector)>>> = Decode::decode(decoder)?;
let tag_weight = tag_weight
.into_iter()
.map(|x| x.into_iter().map(|x| x.into_iter().collect()).collect())
.collect();
Ok(Self {
pma,
weights: Decode::decode(decoder)?,
weights,
tag_weight,
})
}
Expand Down

0 comments on commit 60ba89e

Please sign in to comment.