diff --git a/Cargo.toml b/Cargo.toml index 3c5b1193..0a60007c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "vaporetto", "vaporetto_rules", + "vaporetto_tantivy", "manipulate_model", "predict", "train", diff --git a/evaluate/src/main.rs b/evaluate/src/main.rs index dd0d4b4c..99a28e8e 100644 --- a/evaluate/src/main.rs +++ b/evaluate/src/main.rs @@ -83,13 +83,11 @@ struct Opt { fn main() -> Result<(), Box> { let opt = Opt::from_args(); - let fullwidth_filter = KyteaFullwidthFilter::new(); + let fullwidth_filter = KyteaFullwidthFilter; let mut post_filters: Vec> = vec![]; for wsconst in &opt.wsconst { match wsconst { - WsConst::GraphemeCluster => { - post_filters.push(Box::new(ConcatGraphemeClustersFilter::new())) - } + WsConst::GraphemeCluster => post_filters.push(Box::new(ConcatGraphemeClustersFilter)), WsConst::CharType(char_type) => { post_filters.push(Box::new(KyteaWsConstFilter::new(*char_type))) } diff --git a/manipulate_model/Cargo.toml b/manipulate_model/Cargo.toml index 5139cfd9..3c27d5b9 100644 --- a/manipulate_model/Cargo.toml +++ b/manipulate_model/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2018" [dependencies] -csv = "1.1" # Unlicense OR MIT +csv = "1.1" # Unlicense or MIT serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0 structopt = "0.3" # MIT or Apache-2.0 vaporetto = { path = "../vaporetto" } # MIT or Apache-2.0 diff --git a/predict/src/main.rs b/predict/src/main.rs index 3be29879..a6d4f9be 100644 --- a/predict/src/main.rs +++ b/predict/src/main.rs @@ -110,14 +110,12 @@ fn main() -> Result<(), Box> { let mut pre_filters: Vec> = vec![]; if !opt.no_norm { - pre_filters.push(Box::new(KyteaFullwidthFilter::new())); + pre_filters.push(Box::new(KyteaFullwidthFilter)); } let mut post_filters: Vec> = vec![]; for wsconst in &opt.wsconst { match wsconst { - WsConst::GraphemeCluster => { - post_filters.push(Box::new(ConcatGraphemeClustersFilter::new())) - } + WsConst::GraphemeCluster => post_filters.push(Box::new(ConcatGraphemeClustersFilter)), WsConst::CharType(char_type) => { post_filters.push(Box::new(KyteaWsConstFilter::new(*char_type))) } diff --git a/train/src/main.rs b/train/src/main.rs index 44f5cd5d..6d60f579 100644 --- a/train/src/main.rs +++ b/train/src/main.rs @@ -70,7 +70,7 @@ struct Opt { fn main() -> Result<(), Box> { let opt = Opt::from_args(); - let fullwidth_filter = KyteaFullwidthFilter::new(); + let fullwidth_filter = KyteaFullwidthFilter; eprintln!("Loading dataset..."); let mut train_sents = vec![]; diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml index 02680225..fcbaffc4 100644 --- a/vaporetto/Cargo.toml +++ b/vaporetto/Cargo.toml @@ -10,7 +10,6 @@ repository = "https://github.com/legalforce-research/vaporetto" readme = "README.md" keywords = ["japanese", "analyzer", "tokenizer", "morphological"] categories = ["text-processing"] -autotests = false [dependencies] daachorse = "0.4.0" # MIT or Apache-2.0 diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 18bd0d7a..e5bc3f92 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -1,5 +1,5 @@ use std::iter; -use std::rc::Rc; +use std::sync::Arc; use daachorse::DoubleArrayAhoCorasick; @@ -148,7 +148,7 @@ impl NaiveWeightSet { boundary: None, tag_left: None, tag_right: None, - tag_self: Some(Rc::new(vec![TagRangeScore::new( + tag_self: Some(Arc::new(vec![TagRangeScore::new( start_rel_position, weight, )])), @@ -171,7 +171,7 @@ impl MergableWeight for NaiveWeightSet { tag_self: 
utils::xor_or_zip_with(&weight1.tag_self, &weight2.tag_self, |w1, w2| { let mut w = w1.to_vec(); w.append(&mut w2.to_vec()); - Rc::new(w) + Arc::new(w) }), } } @@ -345,7 +345,7 @@ impl CharScorerWithTags { .add_weight(&mut tag_ys.right_scores, offset); } if let Some(weight) = weight_set.tag_self.as_ref() { - tag_ys.self_scores[m_end - 1].replace(Rc::clone(weight)); + tag_ys.self_scores[m_end - 1].replace(Arc::clone(weight)); } } } diff --git a/vaporetto/src/feature.rs b/vaporetto/src/feature.rs index f8eecdc8..f1f80bb1 100644 --- a/vaporetto/src/feature.rs +++ b/vaporetto/src/feature.rs @@ -1,5 +1,5 @@ use std::hash::Hash; -use std::rc::Rc; +use std::sync::Arc; use daachorse::DoubleArrayAhoCorasick; @@ -213,7 +213,7 @@ impl<'a> TagFeature<'a> { #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct TagExample<'a> { pub features: Vec>, - pub tag: Rc, + pub tag: Arc, } pub struct TagExampleGenerator { @@ -240,8 +240,11 @@ impl TagExampleGenerator { sentence.char_substring(start, sentence.chars.len()), )); } - let mut current_tag: Option> = - sentence.tags.last().and_then(|x| x.as_ref()).map(Rc::clone); + let mut current_tag: Option> = sentence + .tags + .last() + .and_then(|x| x.as_ref()) + .map(Arc::clone); let mut tag_right_pos = sentence.chars.len(); for (i, (t, b)) in sentence .tags @@ -279,7 +282,7 @@ impl TagExampleGenerator { features = vec![]; } if let Some(tag) = t.as_ref() { - current_tag.replace(Rc::clone(tag)); + current_tag.replace(Arc::clone(tag)); tag_right_pos = i + 1; for j in (i + 2)..(i + 2 + self.char_window_size).min(sentence.chars.len() + 1) @@ -479,7 +482,7 @@ mod tests { TagFeature::left_char_ngram_bos(-1, "Ar"), TagFeature::chars("Aria"), ], - tag: Rc::new("名詞".to_string()), + tag: Arc::new("名詞".to_string()), }, TagExample { features: vec![ @@ -503,7 +506,7 @@ mod tests { TagFeature::left_char_ngram(-1, "aは火"), TagFeature::chars("は"), ], - tag: Rc::new("助詞".to_string()), + tag: Arc::new("助詞".to_string()), }, TagExample { features: vec![ @@ -520,7 +523,7 @@ mod tests { TagFeature::left_char_ngram(-1, "猫だ"), TagFeature::chars("だ"), ], - tag: Rc::new("助動詞".to_string()), + tag: Arc::new("助動詞".to_string()), }, ]; @@ -560,7 +563,7 @@ mod tests { TagFeature::left_char_ngram_bos(-1, "Ar"), TagFeature::chars("Aria"), ], - tag: Rc::new("名詞".to_string()), + tag: Arc::new("名詞".to_string()), }, TagExample { features: vec![ @@ -578,7 +581,7 @@ mod tests { TagFeature::left_char_ngram(-1, "aは火"), TagFeature::chars("は"), ], - tag: Rc::new("助詞".to_string()), + tag: Arc::new("助詞".to_string()), }, TagExample { features: vec![ @@ -592,7 +595,7 @@ mod tests { TagFeature::left_char_ngram(-1, "猫だ"), TagFeature::chars("だ"), ], - tag: Rc::new("助動詞".to_string()), + tag: Arc::new("助動詞".to_string()), }, ]; @@ -631,7 +634,7 @@ mod tests { TagFeature::left_char_ngram_bos(-1, "A"), TagFeature::chars("Aria"), ], - tag: Rc::new("名詞".to_string()), + tag: Arc::new("名詞".to_string()), }, TagExample { features: vec![ @@ -649,7 +652,7 @@ mod tests { TagFeature::left_char_ngram(-1, "aは"), TagFeature::chars("は"), ], - tag: Rc::new("助詞".to_string()), + tag: Arc::new("助詞".to_string()), }, TagExample { features: vec![ @@ -663,7 +666,7 @@ mod tests { TagFeature::left_char_ngram(-1, "猫だ"), TagFeature::chars("だ"), ], - tag: Rc::new("助動詞".to_string()), + tag: Arc::new("助動詞".to_string()), }, ]; @@ -704,7 +707,7 @@ mod tests { TagFeature::left_char_ngram_bos(-1, "僕は"), TagFeature::chars("僕"), ], - tag: Rc::new("代名詞".to_string()), + tag: Arc::new("代名詞".to_string()), }, TagExample { features: vec![ @@ -725,7 
+728,7 @@ mod tests { TagFeature::left_char_ngram(-1, "僕"), TagFeature::chars("は"), ], - tag: Rc::new("助詞".to_string()), + tag: Arc::new("助詞".to_string()), }, TagExample { features: vec![ @@ -743,7 +746,7 @@ mod tests { TagFeature::left_char_ngram(-1, "は"), TagFeature::chars("人間"), ], - tag: Rc::new("名詞".to_string()), + tag: Arc::new("名詞".to_string()), }, ]; diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index 1a28804a..2fde23c5 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -1,7 +1,7 @@ use std::mem; use std::cmp::Ordering; -use std::rc::Rc; +use std::sync::Arc; use crate::char_scorer::{self, CharScorer, CharScorerWithTags}; use crate::errors::Result; @@ -24,7 +24,7 @@ pub struct Predictor { padding: usize, // for tag prediction - tag_names: Vec>, + tag_names: Vec>, tag_bias: Vec, } @@ -45,7 +45,7 @@ impl Predictor { let char_scorer = if predict_tags { for cls in model.tag_model.class_info { - tag_names.push(Rc::new(cls.name)); + tag_names.push(Arc::new(cls.name)); tag_bias.push(cls.bias); } CharScorerWrapper::BoundaryAndTags(CharScorerWithTags::new( @@ -142,8 +142,8 @@ impl Predictor { sentence } - fn best_tag(&self, scores: &[i32]) -> Rc { - Rc::clone( + fn best_tag(&self, scores: &[i32]) -> Arc { + Arc::clone( scores .iter() .zip(&self.tag_names) diff --git a/vaporetto/src/sentence.rs b/vaporetto/src/sentence.rs index f2908fe7..a3c0cf41 100644 --- a/vaporetto/src/sentence.rs +++ b/vaporetto/src/sentence.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::sync::Arc; use crate::errors::{Result, VaporettoError}; @@ -110,7 +110,7 @@ impl TagRangeScore { } } -pub type TagRangeScores = Rc>; +pub type TagRangeScores = Arc>; #[derive(Debug, PartialEq, Clone, Default)] pub struct TagScores { @@ -153,7 +153,7 @@ pub struct Sentence { pub(crate) boundaries: Vec, pub(crate) boundary_scores: Vec, pub(crate) tag_scores: TagScores, - pub(crate) tags: Vec>>, + pub(crate) tags: Vec>>, } impl Sentence { @@ -161,7 +161,7 @@ impl Sentence { text: String, chars: Vec, boundaries: Vec, - tags: Vec>>, + tags: Vec>>, ) -> Self { let mut s = Self { text, @@ -202,7 +202,7 @@ impl Sentence { raw_text: &str, chars: &mut Vec, boundaries: &mut Vec, - tags: &mut Vec>>, + tags: &mut Vec>>, ) -> Result<()> { if raw_text.is_empty() { return Err(VaporettoError::invalid_argument( @@ -235,7 +235,7 @@ impl Sentence { text: &mut String, chars: &mut Vec, boundaries: &mut Vec, - tags: &mut Vec>>, + tags: &mut Vec>>, ) -> Result<()> { if tokenized_text.is_empty() { return Err(VaporettoError::invalid_argument( @@ -305,7 +305,7 @@ impl Sentence { } else { BoundaryType::NotWordBoundary }); - tags.push(tag_str.take().map(Rc::new)); + tags.push(tag_str.take().map(Arc::new)); } if c == '\0' { return Err(VaporettoError::invalid_argument( @@ -327,7 +327,7 @@ impl Sentence { "must not end with a whitespace", )); } - tags.push(tag_str_tmp.take().map(Rc::new)); + tags.push(tag_str_tmp.take().map(Arc::new)); Ok(()) } @@ -337,7 +337,7 @@ impl Sentence { text: &mut String, chars: &mut Vec, boundaries: &mut Vec, - tags: &mut Vec>>, + tags: &mut Vec>>, ) -> Result<()> { if labeled_text.is_empty() { return Err(VaporettoError::invalid_argument( @@ -391,7 +391,7 @@ impl Sentence { "POS tag must be annotated to a token".to_string(), )); } - tags.push(tag_str.take().map(Rc::new)); + tags.push(tag_str.take().map(Arc::new)); boundaries.push(BoundaryType::WordBoundary); is_char = true; fixed_token = true; @@ -424,7 +424,7 @@ impl Sentence { } } } - tags.push(tag_str.take().map(Rc::new)); + 
tags.push(tag_str.take().map(Arc::new)); if chars.len() != boundaries.len() + 1 { return Err(VaporettoError::invalid_argument( "labeled_text", @@ -1031,23 +1031,23 @@ impl Sentence { /// # Examples /// /// ``` - /// use std::rc::Rc; + /// use std::sync::Arc; /// /// use vaporetto::{BoundaryType, Sentence}; /// /// let s = Sentence::from_tokenized("I/PRP am a/DT cat/NN ./.").unwrap(); /// assert_eq!(&[ - /// Some(Rc::new("PRP".to_string())), // 'I' + /// Some(Arc::new("PRP".to_string())), // 'I' /// None, // 'a' /// None, // 'm' - /// Some(Rc::new("DT".to_string())), // 'a' + /// Some(Arc::new("DT".to_string())), // 'a' /// None, // 'c' /// None, // 'a' - /// Some(Rc::new("NN".to_string())), // 't' - /// Some(Rc::new(".".to_string())), // '.' + /// Some(Arc::new("NN".to_string())), // 't' + /// Some(Arc::new(".".to_string())), // '.' /// ], s.tags()); /// ``` - pub fn tags(&self) -> &[Option>] { + pub fn tags(&self) -> &[Option>] { &self.tags } @@ -1056,7 +1056,7 @@ impl Sentence { /// # Returns /// /// A mutable reference to the part-of-speech information. - pub fn tags_mut(&mut self) -> &mut [Option>] { + pub fn tags_mut(&mut self) -> &mut [Option>] { &mut self.tags } @@ -1078,6 +1078,34 @@ impl Sentence { &self.chars } + /// Gets immutable references to the characters and character types, and a mutable reference to + /// boundaries. + /// + /// # Returns + /// + /// A tuple of references. + /// + /// # Examples + /// + /// ``` + /// use vaporetto::{BoundaryType, Sentence}; + /// + /// let mut s = Sentence::from_partial_annotation("A-1|あ エ-漢|?").unwrap(); + /// let (chars, char_types, boundaries) = s.chars_and_boundaries_mut(); + /// assert_eq!(&['A', '1', 'あ', 'エ', '漢', '?'], chars); + /// assert_eq!(&[b'R', b'D', b'H', b'T', b'K', b'O'], char_types); + /// assert_eq!(&[ + /// BoundaryType::NotWordBoundary, + /// BoundaryType::WordBoundary, + /// BoundaryType::Unknown, + /// BoundaryType::NotWordBoundary, + /// BoundaryType::WordBoundary, + /// ], boundaries); + /// ``` + pub fn chars_and_boundaries_mut(&mut self) -> (&[char], &[u8], &mut [BoundaryType]) { + (&self.chars, &self.char_type, &mut self.boundaries) + } + /// Gets a reference to the character type information. 
/// /// # Returns @@ -1090,7 +1118,7 @@ impl Sentence { /// use vaporetto::Sentence; /// /// let s = Sentence::from_raw("A1あエ漢?").unwrap(); - /// assert_eq!(&[b'R', b'D', b'H', b'T', b'K', b'O',], s.char_types()); + /// assert_eq!(&[b'R', b'D', b'H', b'T', b'K', b'O'], s.char_types()); /// ``` pub fn char_types(&self) -> &[u8] { &self.char_type @@ -1601,10 +1629,10 @@ mod tests { None, None, None, - Some(Rc::new("名詞".to_string())), + Some(Arc::new("名詞".to_string())), None, None, - Some(Rc::new("形容詞".to_string())), + Some(Arc::new("形容詞".to_string())), None, None, None, @@ -1615,7 +1643,7 @@ mod tests { None, None, None, - Some(Rc::new("補助記号".to_string())), + Some(Arc::new("補助記号".to_string())), ], }; assert_eq!(expected, s.unwrap()); @@ -1712,10 +1740,10 @@ mod tests { None, None, None, - Some(Rc::new("名詞".to_string())), + Some(Arc::new("名詞".to_string())), None, None, - Some(Rc::new("形容詞".to_string())), + Some(Arc::new("形容詞".to_string())), None, None, None, @@ -1726,7 +1754,7 @@ mod tests { None, None, None, - Some(Rc::new("補助記号".to_string())), + Some(Arc::new("補助記号".to_string())), ], }; assert_eq!(expected, s); diff --git a/vaporetto_rules/Cargo.toml b/vaporetto_rules/Cargo.toml index df653066..2233f39f 100644 --- a/vaporetto_rules/Cargo.toml +++ b/vaporetto_rules/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vaporetto_rules" -version = "0.1.5" +version = "0.3.0" edition = "2018" authors = ["Koichi Akabe "] description = "Rule-base filters for Vaporetto" @@ -10,8 +10,7 @@ repository = "https://github.com/legalforce-research/vaporetto" readme = "README.md" keywords = ["japanese", "analyzer", "tokenizer", "morphological"] categories = ["text-processing"] -autotests = false [dependencies] -unicode-segmentation = "1.8.0" # MIT or Apache-2.0 +unicode-segmentation = "1.9.0" # MIT or Apache-2.0 vaporetto = { path = "../vaporetto", version = "0.3.0" } # MIT or Apache-2.0 diff --git a/vaporetto_rules/src/lib.rs b/vaporetto_rules/src/lib.rs index 6cf9de17..fb3585f5 100644 --- a/vaporetto_rules/src/lib.rs +++ b/vaporetto_rules/src/lib.rs @@ -21,10 +21,10 @@ //! let mut predictor = Predictor::new(model, false).unwrap(); //! //! let pre_filters: Vec> = vec![ -//! Box::new(KyteaFullwidthFilter::new()), +//! Box::new(KyteaFullwidthFilter), //! ]; //! let post_filters: Vec> = vec![ -//! Box::new(ConcatGraphemeClustersFilter::new()), +//! Box::new(ConcatGraphemeClustersFilter), //! Box::new(KyteaWsConstFilter::new(CharacterType::Digit)), //! ]; //! @@ -52,7 +52,7 @@ pub mod string_filters; use vaporetto::Sentence; -pub trait SentenceFilter { +pub trait SentenceFilter: Send + Sync { /// Filter a specified sentence using rules. /// /// # Arguments: @@ -65,7 +65,7 @@ pub trait SentenceFilter { fn filter(&self, sentence: Sentence) -> Sentence; } -pub trait StringFilter { +pub trait StringFilter: Send + Sync { /// Filter a specified string using rules. 
/// /// # Arguments: diff --git a/vaporetto_rules/src/sentence_filters.rs b/vaporetto_rules/src/sentence_filters.rs index cd968e80..b701ec2a 100644 --- a/vaporetto_rules/src/sentence_filters.rs +++ b/vaporetto_rules/src/sentence_filters.rs @@ -2,6 +2,8 @@ mod concat_grapheme_clusters; mod kytea_wsconst; +mod split_linebreaks; pub use concat_grapheme_clusters::ConcatGraphemeClustersFilter; pub use kytea_wsconst::KyteaWsConstFilter; +pub use split_linebreaks::SplitLinebreaksFilter; diff --git a/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs b/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs index 287ae38b..39d149b2 100644 --- a/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs +++ b/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs @@ -4,47 +4,18 @@ use vaporetto::{BoundaryType, Sentence}; use crate::SentenceFilter; /// Grapheme cluster concatenator. +#[derive(Clone, Default)] pub struct ConcatGraphemeClustersFilter; -impl ConcatGraphemeClustersFilter { - /// Creates a new ConcatGraphemeClustersFilter. - /// - /// # Returns - /// - /// A new ConcatGraphemeClustersFilter. - pub const fn new() -> Self { - Self {} - } -} - -impl Default for ConcatGraphemeClustersFilter { - fn default() -> Self { - Self::new() - } -} - impl SentenceFilter for ConcatGraphemeClustersFilter { - /// Concatenates grapheme clusters. - /// - /// # Arguments: - /// - /// * `sentence` - Input sentence. - /// - /// # Returns - /// - /// A processed sentence. fn filter(&self, mut sentence: Sentence) -> Sentence { let mut tmp = sentence.boundaries().to_vec(); - for (i, c) in UnicodeSegmentation::grapheme_indices(sentence.to_raw_string(), true) { + for (i, c) in sentence.to_raw_string().grapheme_indices(true) { let start = sentence.get_char_pos(i).unwrap(); let end = sentence.get_char_pos(i + c.len()).unwrap() - 1; - for b in &mut tmp[start..end] { - *b = BoundaryType::NotWordBoundary; - } - } - for (b, t) in sentence.boundaries_mut().iter_mut().zip(&tmp) { - *b = *t; + tmp[start..end].fill(BoundaryType::NotWordBoundary); } + sentence.boundaries_mut().copy_from_slice(&tmp); sentence } } @@ -56,7 +27,7 @@ mod tests { #[test] fn test_concat_grapheme_clusters_no_boundary() { let s = Sentence::from_tokenized("\u{200d}").unwrap(); - let filter = ConcatGraphemeClustersFilter::new(); + let filter = ConcatGraphemeClustersFilter; let s = filter.filter(s); assert_eq!("\u{200d}", s.to_tokenized_string().unwrap()); } @@ -65,7 +36,7 @@ mod tests { fn test_concat_grapheme_clusters_zwj() { let s = Sentence::from_tokenized("\u{1f468} \u{200d} \u{1f469} \u{200d} \u{1f466}").unwrap(); - let filter = ConcatGraphemeClustersFilter::new(); + let filter = ConcatGraphemeClustersFilter; let s = filter.filter(s); assert_eq!( "\u{1f468}\u{200d}\u{1f469}\u{200d}\u{1f466}", @@ -76,7 +47,7 @@ mod tests { #[test] fn test_concat_grapheme_clusters_color() { let s = Sentence::from_tokenized("\u{1f44f} \u{1f3fd}").unwrap(); - let filter = ConcatGraphemeClustersFilter::new(); + let filter = ConcatGraphemeClustersFilter; let s = filter.filter(s); assert_eq!("\u{1f44f}\u{1f3fd}", s.to_tokenized_string().unwrap()); } @@ -84,7 +55,7 @@ mod tests { #[test] fn test_concat_grapheme_clusters_combined() { let s = Sentence::from_tokenized("これ は 手 \u{1f44f} \u{1f3fd} で す").unwrap(); - let filter = ConcatGraphemeClustersFilter::new(); + let filter = ConcatGraphemeClustersFilter; let s = filter.filter(s); assert_eq!( "これ は 手 \u{1f44f}\u{1f3fd} で す", diff --git 
a/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs b/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs index bd0d1318..07d69964 100644 --- a/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs +++ b/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs @@ -3,6 +3,7 @@ use vaporetto::{BoundaryType, CharacterType, Sentence}; use crate::SentenceFilter; /// Character type concatenator. This filter works like KyTea's wsconst option. +#[derive(Clone)] pub struct KyteaWsConstFilter { char_type: CharacterType, } @@ -23,26 +24,14 @@ impl KyteaWsConstFilter { } impl SentenceFilter for KyteaWsConstFilter { - /// Concatenates consecutive character types. - /// - /// # Arguments: - /// - /// * `sentence` - Input sentence. - /// - /// # Returns - /// - /// A processed sentence. fn filter(&self, mut sentence: Sentence) -> Sentence { let t_flag = self.char_type as u8; - let mut tmp = sentence.boundaries().to_vec(); - for (i, (b, &t)) in tmp.iter_mut().zip(sentence.char_types()).enumerate() { - if t == t_flag && t == sentence.char_types()[i + 1] { + let (_, char_types, boundaries) = sentence.chars_and_boundaries_mut(); + for ((t1, t2), b) in char_types.iter().zip(&char_types[1..]).zip(boundaries) { + if *t1 == t_flag && *t2 == t_flag { *b = BoundaryType::NotWordBoundary; } } - for (b, t) in sentence.boundaries_mut().iter_mut().zip(&tmp) { - *b = *t; - } sentence } } diff --git a/vaporetto_rules/src/sentence_filters/split_linebreaks.rs b/vaporetto_rules/src/sentence_filters/split_linebreaks.rs new file mode 100644 index 00000000..71156946 --- /dev/null +++ b/vaporetto_rules/src/sentence_filters/split_linebreaks.rs @@ -0,0 +1,51 @@ +use vaporetto::{BoundaryType, Sentence}; + +use crate::SentenceFilter; + +/// Line breaks splitter. +#[derive(Clone, Default)] +pub struct SplitLinebreaksFilter; + +impl SentenceFilter for SplitLinebreaksFilter { + fn filter(&self, mut sentence: Sentence) -> Sentence { + let (chars, _, boundaries) = sentence.chars_and_boundaries_mut(); + for ((c1, c2), b) in chars.iter().zip(&chars[1..]).zip(boundaries) { + match (*c1, *c2) { + ('\r' | '\n', _) | (_, '\r' | '\n') => { + *b = BoundaryType::WordBoundary; + } + _ => {} + } + } + sentence + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_lf() { + let s = Sentence::from_tokenized("前の行\n次の行").unwrap(); + let filter = SplitLinebreaksFilter; + let s = filter.filter(s); + assert_eq!("前の行 \n 次の行", s.to_tokenized_string().unwrap()); + } + + #[test] + fn test_split_cr() { + let s = Sentence::from_tokenized("前の行\r次の行").unwrap(); + let filter = SplitLinebreaksFilter; + let s = filter.filter(s); + assert_eq!("前の行 \r 次の行", s.to_tokenized_string().unwrap()); + } + + #[test] + fn test_split_crlf() { + let s = Sentence::from_tokenized("前の行\r\n次の行").unwrap(); + let filter = SplitLinebreaksFilter; + let s = filter.filter(s); + assert_eq!("前の行 \r \n 次の行", s.to_tokenized_string().unwrap()); + } +} diff --git a/vaporetto_rules/src/string_filters/kytea_fullwidth.rs b/vaporetto_rules/src/string_filters/kytea_fullwidth.rs index 3dc841fc..abefc99d 100644 --- a/vaporetto_rules/src/string_filters/kytea_fullwidth.rs +++ b/vaporetto_rules/src/string_filters/kytea_fullwidth.rs @@ -1,35 +1,10 @@ use crate::StringFilter; /// Half-width to full-width filter. This filter works like KyTea's preprocessor. +#[derive(Clone, Default)] pub struct KyteaFullwidthFilter; -impl KyteaFullwidthFilter { - /// Creates a new KyteaFullwidthFilter. - /// - /// # Returns - /// - /// A new KyteaFullwidthFilter. 
- pub const fn new() -> Self { - Self {} - } -} - -impl Default for KyteaFullwidthFilter { - fn default() -> Self { - Self::new() - } -} - impl StringFilter for KyteaFullwidthFilter { - /// Replace alphanumerics and symbols to full-width characters. - /// - /// # Arguments: - /// - /// * `text` - Input text. - /// - /// # Returns - /// - /// A processed text. fn filter(&self, string: &str) -> String { let mut chars: Vec<_> = string.chars().collect(); for c in &mut chars { diff --git a/vaporetto_tantivy/Cargo.toml b/vaporetto_tantivy/Cargo.toml new file mode 100644 index 00000000..4694b884 --- /dev/null +++ b/vaporetto_tantivy/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "vaporetto_tantivy" +version = "0.3.0" +edition = "2021" +authors = ["Koichi Akabe "] +description = "Vaporetto Tokenizer for Tantivy" +license = "MIT OR Apache-2.0" +homepage = "https://github.com/legalforce-research/vaporetto" +repository = "https://github.com/legalforce-research/vaporetto" +readme = "README.md" +keywords = ["japanese", "tokenizer", "tantivy"] +categories = ["text-processing"] + +[dependencies] +vaporetto = { path = "../vaporetto", version = "0.3.0" } # MIT or Apache-2.0 +vaporetto_rules = { path = "../vaporetto_rules", version = "0.3.0" } # MIT or Apache-2.0 +tantivy = "0.16" # MIT + +[dev-dependencies] +ruzstd = "0.2.4" # MIT diff --git a/vaporetto_tantivy/README.md b/vaporetto_tantivy/README.md new file mode 100644 index 00000000..acbb8d73 --- /dev/null +++ b/vaporetto_tantivy/README.md @@ -0,0 +1,40 @@ +# vaporetto_tantivy + +Vaporetto is a fast and lightweight pointwise prediction based tokenizer. +vaporetto_tantivy is a crate to use Vaporetto in [Tantivy](https://github.com/quickwit-oss/tantivy). + +# Example + +```rust +use std::fs::File; +use std::io::{Read, BufReader}; + +use tantivy::schema::{IndexRecordOption, Schema, TextFieldIndexing, TextOptions}; +use tantivy::Index; +use vaporetto::Model; +use vaporetto_tantivy::VaporettoTokenizer; + +let mut schema_builder = Schema::builder(); +let text_field_indexing = TextFieldIndexing::default() + .set_tokenizer("ja_vaporetto") + .set_index_option(IndexRecordOption::WithFreqsAndPositions); +let text_options = TextOptions::default() + .set_indexing_options(text_field_indexing) + .set_stored(); +schema_builder.add_text_field("title", text_options); +let schema = schema_builder.build(); +let index = Index::create_in_ram(schema); + +// Loads a model with decompression. +let mut f = BufReader::new(File::open("bccwj-suw+unidic.model.zst").unwrap()); +let mut decoder = ruzstd::StreamingDecoder::new(&mut f).unwrap(); +let mut buff = vec![]; +decoder.read_to_end(&mut buff).unwrap(); +let model = Model::read(&mut buff.as_slice()).unwrap(); + +// Creates VaporettoTokenizer with wsconst=DGR. +let tokenizer = VaporettoTokenizer::new(model, "DGR").unwrap(); +index + .tokenizers() + .register("ja_vaporetto", tokenizer); +``` diff --git a/vaporetto_tantivy/src/lib.rs b/vaporetto_tantivy/src/lib.rs new file mode 100644 index 00000000..ec2f375e --- /dev/null +++ b/vaporetto_tantivy/src/lib.rs @@ -0,0 +1,448 @@ +//! # vaporetto_tantivy +//! +//! Vaporetto Tokenizer for Tantivy +//! +//! ## Examples +//! +//! ```no_run +//! use std::fs::File; +//! use std::io::{Read, BufReader}; +//! +//! use tantivy::tokenizer::Tokenizer; +//! use vaporetto::Model; +//! use vaporetto_tantivy::VaporettoTokenizer; +//! +//! let mut f = BufReader::new(File::open("model.zst").unwrap()); +//! let mut decoder = ruzstd::StreamingDecoder::new(&mut f).unwrap(); +//! 
let mut buff = vec![]; +//! decoder.read_to_end(&mut buff).unwrap(); +//! let model = Model::read(&mut buff.as_slice()).unwrap(); +//! +//! let tokenizer = VaporettoTokenizer::new(model, "DGR").unwrap(); +//! +//! let mut stream = tokenizer.token_stream("東京特許許可局"); +//! +//! let token = stream.next().unwrap(); +//! assert_eq!(token.text, "東京"); +//! assert_eq!(token.offset_from, 0); +//! assert_eq!(token.offset_to, 6); +//! assert_eq!(token.position, 0); +//! +//! let token = stream.next().unwrap(); +//! assert_eq!(token.text, "特許"); +//! assert_eq!(token.offset_from, 6); +//! assert_eq!(token.offset_to, 12); +//! assert_eq!(token.position, 1); +//! +//! let token = stream.next().unwrap(); +//! assert_eq!(token.text, "許可"); +//! assert_eq!(token.offset_from, 12); +//! assert_eq!(token.offset_to, 18); +//! assert_eq!(token.position, 2); +//! +//! let token = stream.next().unwrap(); +//! assert_eq!(token.text, "局"); +//! assert_eq!(token.offset_from, 18); +//! assert_eq!(token.offset_to, 21); +//! assert_eq!(token.position, 3); +//! +//! assert!(stream.next().is_none()); +//! ``` +use std::sync::Arc; + +use tantivy::tokenizer::{BoxTokenStream, Token, TokenStream, Tokenizer}; +use vaporetto::{BoundaryType, CharacterType, Model, Predictor, Sentence}; +use vaporetto_rules::{ + sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter, SplitLinebreaksFilter}, + string_filters::KyteaFullwidthFilter, + SentenceFilter, StringFilter, +}; + +/// Tokenizes text using Vaporetto. +#[derive(Clone)] +pub struct VaporettoTokenizer { + predictor: Arc<Predictor>, + prefilter: KyteaFullwidthFilter, + postfilters: Vec<Arc<dyn SentenceFilter>>, +} + +impl VaporettoTokenizer { + /// Creates a new VaporettoTokenizer. + /// + /// # Arguments + /// + /// * `model` - A Vaporetto model. + /// * `wsconst` - Character types that the tokenizer does not segment. + /// D: Digit, R: Roman, H: Hiragana, T: Katakana, K: Kanji, O: Other, + /// G: Grapheme cluster. + /// + /// # Errors + /// + /// An error is returned when + /// - the model is invalid, or + /// - `wsconst` contains an invalid character type.
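+ /// + /// For example, `VaporettoTokenizer::new(model, "DG")` keeps runs of digits and whole + /// grapheme clusters unsegmented, as exercised by the `test_tokenize_wsconst_dg` test below.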
+ pub fn new(model: Model, wsconst: &str) -> Result<Self, Box<dyn std::error::Error>> { + let mut postfilters: Vec<Arc<dyn SentenceFilter>> = vec![Arc::new(SplitLinebreaksFilter)]; + for c in wsconst.chars() { + postfilters.push(match c { + 'D' => Arc::new(KyteaWsConstFilter::new(CharacterType::Digit)), + 'R' => Arc::new(KyteaWsConstFilter::new(CharacterType::Roman)), + 'H' => Arc::new(KyteaWsConstFilter::new(CharacterType::Hiragana)), + 'T' => Arc::new(KyteaWsConstFilter::new(CharacterType::Katakana)), + 'K' => Arc::new(KyteaWsConstFilter::new(CharacterType::Kanji)), + 'O' => Arc::new(KyteaWsConstFilter::new(CharacterType::Other)), + 'G' => Arc::new(ConcatGraphemeClustersFilter), + _ => return Err("Could not parse a wsconst value".into()), + }); + } + Ok(Self { + predictor: Arc::new(Predictor::new(model, false)?), + prefilter: KyteaFullwidthFilter, + postfilters, + }) + } +} + +pub struct VaporettoTokenStream<'a> { + text: &'a str, + token: Token, + boundary_pos: Vec<usize>, + offset_to: usize, + position: usize, +} + +impl Tokenizer for VaporettoTokenizer { + fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> { + if text.is_empty() { + return BoxTokenStream::from(VaporettoTokenStream { + text, + boundary_pos: vec![], + token: Token::default(), + offset_to: 0, + position: 0, + }); + } + + // pre filter + let prefiltered_text = self.prefilter.filter(text); + let prefiltered_sentence = Sentence::from_raw(prefiltered_text).unwrap(); + + // tokenize + let tokenized_sentence = self.predictor.predict(prefiltered_sentence); + + // post filter + let postfiltered_sentence = self + .postfilters + .iter() + .fold(tokenized_sentence, |s, filter| filter.filter(s)); + + let mut char_indices = text.char_indices(); + char_indices.next(); + let mut boundary_pos = Vec::with_capacity(postfiltered_sentence.chars().len()); + for ((i, _), &b) in char_indices.zip(postfiltered_sentence.boundaries()) { + if b == BoundaryType::WordBoundary { + boundary_pos.push(i); + } + } + boundary_pos.push(text.len()); + + BoxTokenStream::from(VaporettoTokenStream { + text, + token: Token::default(), + boundary_pos, + offset_to: 0, + position: 0, + }) + } +} + +impl<'a> TokenStream for VaporettoTokenStream<'a> { + fn advance(&mut self) -> bool { + if self.position < self.boundary_pos.len() { + self.token.offset_from = self.offset_to; + self.offset_to = self.boundary_pos[self.position]; + self.token.offset_to = self.offset_to; + self.token.text.clear(); + self.token + .text + .push_str(&self.text[self.token.offset_from..self.token.offset_to]); + self.token.position = self.position; + self.token.position_length = self.boundary_pos.len(); + self.position += 1; + true + } else { + false + } + } + + fn token(&self) -> &Token { + &self.token + } + + fn token_mut(&mut self) -> &mut Token { + &mut self.token + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::{Cursor, Read}; + + use tantivy::tokenizer::TextAnalyzer; + + fn token_stream_helper(text: &str, wsconst: &str) -> Vec<Token> { + let mut f = Cursor::new(include_bytes!("../test_model/model.zst")); + let mut decoder = ruzstd::StreamingDecoder::new(&mut f).unwrap(); + let mut buff = vec![]; + decoder.read_to_end(&mut buff).unwrap(); + let model = Model::read(&mut buff.as_slice()).unwrap(); + let a = TextAnalyzer::from(VaporettoTokenizer::new(model, wsconst).unwrap()); + let mut token_stream = a.token_stream(text); + let mut tokens: Vec<Token> = vec![]; + let mut add_token = |token: &Token| { + tokens.push(token.clone()); + }; + token_stream.process(&mut add_token); + tokens + } + + #[test] + fn test_tokenize_empty() { + let
tokens = token_stream_helper("", ""); + + assert_eq!(tokens.len(), 0); + } + + #[test] + fn test_tokenizer_tokyo() { + let tokens = token_stream_helper("東京特許許可局", ""); + + assert_eq!(tokens.len(), 4); + + let token = &tokens[0]; + assert_eq!(token.text, "東京"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 4); + + let token = &tokens[1]; + assert_eq!(token.text, "特許"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 12); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 4); + + let token = &tokens[2]; + assert_eq!(token.text, "許可"); + assert_eq!(token.offset_from, 12); + assert_eq!(token.offset_to, 18); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 4); + + let token = &tokens[3]; + assert_eq!(token.text, "局"); + assert_eq!(token.offset_from, 18); + assert_eq!(token.offset_to, 21); + assert_eq!(token.position, 3); + assert_eq!(token.position_length, 4); + } + + #[test] + fn test_tokenizer_no_wsconst() { + let tokens = token_stream_helper("123456円🤌🏿", ""); + + assert_eq!(tokens.len(), 9); + + let token = &tokens[0]; + assert_eq!(token.text, "1"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 1); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 9); + + let token = &tokens[1]; + assert_eq!(token.text, "2"); + assert_eq!(token.offset_from, 1); + assert_eq!(token.offset_to, 2); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 9); + + let token = &tokens[2]; + assert_eq!(token.text, "3"); + assert_eq!(token.offset_from, 2); + assert_eq!(token.offset_to, 3); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 9); + + let token = &tokens[3]; + assert_eq!(token.text, "4"); + assert_eq!(token.offset_from, 3); + assert_eq!(token.offset_to, 4); + assert_eq!(token.position, 3); + assert_eq!(token.position_length, 9); + + let token = &tokens[4]; + assert_eq!(token.text, "5"); + assert_eq!(token.offset_from, 4); + assert_eq!(token.offset_to, 5); + assert_eq!(token.position, 4); + assert_eq!(token.position_length, 9); + + let token = &tokens[5]; + assert_eq!(token.text, "6"); + assert_eq!(token.offset_from, 5); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 5); + assert_eq!(token.position_length, 9); + + let token = &tokens[6]; + assert_eq!(token.text, "円"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 9); + assert_eq!(token.position, 6); + assert_eq!(token.position_length, 9); + + let token = &tokens[7]; + assert_eq!(token.text, "🤌"); + assert_eq!(token.offset_from, 9); + assert_eq!(token.offset_to, 13); + assert_eq!(token.position, 7); + assert_eq!(token.position_length, 9); + + let token = &tokens[8]; + assert_eq!(token.text, "🏿"); + assert_eq!(token.offset_from, 13); + assert_eq!(token.offset_to, 17); + assert_eq!(token.position, 8); + assert_eq!(token.position_length, 9); + } + + #[test] + fn test_tokenize_wsconst_d() { + let tokens = token_stream_helper("123456円🤌🏿", "D"); + + assert_eq!(tokens.len(), 4); + + let token = &tokens[0]; + assert_eq!(token.text, "123456"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 4); + + let token = &tokens[1]; + assert_eq!(token.text, "円"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 9); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 4); + + let token = &tokens[2]; + 
assert_eq!(token.text, "🤌"); + assert_eq!(token.offset_from, 9); + assert_eq!(token.offset_to, 13); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 4); + + let token = &tokens[3]; + assert_eq!(token.text, "🏿"); + assert_eq!(token.offset_from, 13); + assert_eq!(token.offset_to, 17); + assert_eq!(token.position, 3); + assert_eq!(token.position_length, 4); + } + + #[test] + fn test_tokenizer_wsconst_g() { + let tokens = token_stream_helper("123456円🤌🏿", "G"); + + assert_eq!(tokens.len(), 8); + + let token = &tokens[0]; + assert_eq!(token.text, "1"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 1); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 8); + + let token = &tokens[1]; + assert_eq!(token.text, "2"); + assert_eq!(token.offset_from, 1); + assert_eq!(token.offset_to, 2); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 8); + + let token = &tokens[2]; + assert_eq!(token.text, "3"); + assert_eq!(token.offset_from, 2); + assert_eq!(token.offset_to, 3); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 8); + + let token = &tokens[3]; + assert_eq!(token.text, "4"); + assert_eq!(token.offset_from, 3); + assert_eq!(token.offset_to, 4); + assert_eq!(token.position, 3); + assert_eq!(token.position_length, 8); + + let token = &tokens[4]; + assert_eq!(token.text, "5"); + assert_eq!(token.offset_from, 4); + assert_eq!(token.offset_to, 5); + assert_eq!(token.position, 4); + assert_eq!(token.position_length, 8); + + let token = &tokens[5]; + assert_eq!(token.text, "6"); + assert_eq!(token.offset_from, 5); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 5); + assert_eq!(token.position_length, 8); + + let token = &tokens[6]; + assert_eq!(token.text, "円"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 9); + assert_eq!(token.position, 6); + assert_eq!(token.position_length, 8); + + let token = &tokens[7]; + assert_eq!(token.text, "🤌🏿"); + assert_eq!(token.offset_from, 9); + assert_eq!(token.offset_to, 17); + assert_eq!(token.position, 7); + assert_eq!(token.position_length, 8); + } + + #[test] + fn test_tokenize_wsconst_dg() { + let tokens = token_stream_helper("123456円🤌🏿", "DG"); + + assert_eq!(tokens.len(), 3); + + let token = &tokens[0]; + assert_eq!(token.text, "123456"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 3); + + let token = &tokens[1]; + assert_eq!(token.text, "円"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 9); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 3); + + let token = &tokens[2]; + assert_eq!(token.text, "🤌🏿"); + assert_eq!(token.offset_from, 9); + assert_eq!(token.offset_to, 17); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 3); + } +} diff --git a/vaporetto_tantivy/test_model/model.zst b/vaporetto_tantivy/test_model/model.zst new file mode 100644 index 00000000..e51157d3 Binary files /dev/null and b/vaporetto_tantivy/test_model/model.zst differ