From 58ef5f3c8cdb55df1ad7969f7df338b92da58c45 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 6 Jun 2022 20:47:44 +0900 Subject: [PATCH] Reimplement Vaporetto with supporting multiple tags (#35) * Reimplement Vaporetto with supporting multiple tags * Update examples * fmt all * clippy and fix bugs * fmt * fix bugs * Add charwise-pma feature * fmt * clippy * fix * fix features * docs * Implement Eq to CharacterBoundary * fix features * fix CI * Update README * update README-ja * fix a bug * fix README * Update vaporetto/src/dict_model.rs Co-authored-by: Shunsuke Kanda * Update vaporetto/src/utils.rs Co-authored-by: Shunsuke Kanda * Update vaporetto/src/dict_model.rs Co-authored-by: Shunsuke Kanda * Update docs of Sentence * update docs * Add a description of Sentence::default() * Update docs * fix docs * fmt * fix minor * fix lint * fix doc * Update vaporetto/src/sentence.rs Co-authored-by: Shunsuke Kanda * Update vaporetto/src/sentence.rs Co-authored-by: Shunsuke Kanda * Update vaporetto/src/sentence.rs Co-authored-by: Shunsuke Kanda * Update vaporetto/src/sentence.rs Co-authored-by: Shunsuke Kanda * Update vaporetto/src/sentence.rs Co-authored-by: Shunsuke Kanda * Update vaporetto/src/sentence.rs Co-authored-by: Shunsuke Kanda * fix docs * Update vaporetto/src/sentence.rs Co-authored-by: Shunsuke Kanda * fix docs * update docs * str_to_char_pos * fix * Update vaporetto/src/sentence.rs Co-authored-by: Shunsuke Kanda * Update sentence.rs * debug_assert() for daachorse * Update vaporetto/src/char_scorer.rs Co-authored-by: Shunsuke Kanda * fix * fix arg names * fix * fix * check window_size * fix * Add Nop * Revert "Add Nop" This reverts commit 3b8d9d9efe59a9e2759b1f77da4972e31f355666. * Revert "fix" This reverts commit ccf3657223019496bd3c9d5b196ce25d47fb45e8. * Revert "check window_size" This reverts commit d17062e85c4cc18ffc63d9f699c8f629214ab9e7. 
* fix serialization * fix * wrap tag_predictor by Option * add panic test * add comments for unsafe blocks * fix * add trim_end_zeros() * refactoring * add TagWeight * Add description of TagModel * Update model.rs * Use HashMap in model * fix * tantivy 0.18 * fix * Add wrapper of HashMap for bincode * fix * fix * fix * fix * fix * add cfg annotation * fix * fix * Use Vec in Model * Apply suggestions from code review Co-authored-by: Shunsuke Kanda * Add debug_assert!() * Remove redundant clone() * fix * fix * Apply suggestions from code review Co-authored-by: Shunsuke Kanda Co-authored-by: Shunsuke Kanda --- .github/workflows/rust.yml | 12 +- README-ja.md | 10 +- README.md | 8 +- evaluate/src/main.rs | 45 +- examples/embedded_device/Cargo.toml | 2 +- examples/embedded_device/src/main.rs | 28 +- examples/wasm/src/lib.rs | 49 +- examples/wasm/www/index.js | 2 +- predict/src/main.rs | 41 +- resources/docs.tok | 2 + resources/model.bin | Bin 0 -> 394 bytes train/src/main.rs | 29 +- vaporetto/Cargo.toml | 16 +- vaporetto/README.md | 39 +- vaporetto/src/char_scorer.rs | 935 +++-- vaporetto/src/char_scorer/boundary_scorer.rs | 113 + .../src/char_scorer/boundary_tag_scorer.rs | 185 + vaporetto/src/dict_model.rs | 19 +- vaporetto/src/errors.rs | 40 +- vaporetto/src/feature.rs | 769 ---- vaporetto/src/kytea_model.rs | 7 +- vaporetto/src/lib.rs | 65 +- vaporetto/src/model.rs | 106 +- vaporetto/src/ngram_model.rs | 21 +- vaporetto/src/predictor.rs | 1783 +++++----- vaporetto/src/sentence.rs | 3095 +++++++++-------- vaporetto/src/tag_model.rs | 28 - vaporetto/src/tag_trainer.rs | 417 ++- vaporetto/src/trainer.rs | 798 ++++- vaporetto/src/type_scorer.rs | 599 +++- vaporetto/src/type_scorer/boundary_scorer.rs | 79 + .../src/type_scorer/boundary_scorer_cache.rs | 109 + .../src/type_scorer/boundary_tag_scorer.rs | 152 + vaporetto/src/utils.rs | 223 +- vaporetto_rules/Cargo.toml | 8 +- vaporetto_rules/README.md | 22 +- vaporetto_rules/src/lib.rs | 46 +- .../concat_grapheme_clusters.rs | 72 +- .../src/sentence_filters/kytea_wsconst.rs | 47 +- .../src/sentence_filters/split_linebreaks.rs | 59 +- vaporetto_rules/src/string_filters.rs | 2 +- .../src/string_filters/kytea_fullwidth.rs | 18 +- vaporetto_tantivy/Cargo.toml | 8 +- vaporetto_tantivy/src/lib.rs | 17 +- vaporetto_tantivy/test_model/model.zst | Bin 270 -> 332 bytes 45 files changed, 5477 insertions(+), 4648 deletions(-) create mode 100644 resources/docs.tok create mode 100644 resources/model.bin create mode 100644 vaporetto/src/char_scorer/boundary_scorer.rs create mode 100644 vaporetto/src/char_scorer/boundary_tag_scorer.rs delete mode 100644 vaporetto/src/feature.rs delete mode 100644 vaporetto/src/tag_model.rs create mode 100644 vaporetto/src/type_scorer/boundary_scorer.rs create mode 100644 vaporetto/src/type_scorer/boundary_scorer_cache.rs create mode 100644 vaporetto/src/type_scorer/boundary_tag_scorer.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 4ae3c3af..022ca394 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -34,7 +34,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: clippy - args: -- -D warnings -W clippy::nursery -W clippy::cast_lossless -W clippy::cast_possible_truncation -W clippy::cast_possible_wrap + args: -- -D warnings -W clippy::nursery -W clippy::cast_lossless -W clippy::cast_possible_truncation -W clippy::cast_possible_wrap -A clippy::empty_line_after_outer_attr - name: Run cargo test (workspace) uses: actions-rs/cargo@v1 @@ -46,7 +46,7 @@ jobs: uses: actions-rs/cargo@v1 
with: command: test - args: --release -p vaporetto --no-default-features + args: --release -p vaporetto --no-default-features --features alloc - name: Run cargo test (vaporetto / features kytea) uses: actions-rs/cargo@v1 @@ -82,7 +82,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: --release -p vaporetto --no-default-features --features charwise-daachorse + args: --release -p vaporetto --no-default-features --features charwise-pma - name: Run cargo test (vaporetto / features std) uses: actions-rs/cargo@v1 @@ -117,7 +117,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: clippy - args: -- -D warnings -W clippy::nursery -W clippy::cast_lossless -W clippy::cast_possible_truncation -W clippy::cast_possible_wrap + args: -- -D warnings -W clippy::nursery -W clippy::cast_lossless -W clippy::cast_possible_truncation -W clippy::cast_possible_wrap -A clippy::empty_line_after_outer_attr - name: Run cargo test (workspace) uses: actions-rs/cargo@v1 @@ -129,7 +129,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: --release -p vaporetto --no-default-features + args: --release -p vaporetto --no-default-features --features alloc - name: Run cargo test (vaporetto / all-features) uses: actions-rs/cargo@v1 @@ -171,7 +171,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: --release -p vaporetto --no-default-features --features charwise-daachorse + args: --release -p vaporetto --no-default-features --features charwise-pma - name: Run cargo test (vaporetto / features std) uses: actions-rs/cargo@v1 diff --git a/README-ja.md b/README-ja.md index 14045808..13e0358b 100644 --- a/README-ja.md +++ b/README-ja.md @@ -187,13 +187,13 @@ Vaporetto は2種類のコーパス、すなわちフルアノテーションコ ### 品詞推定 -Vaporettoは実験的に品詞推定に対応しています。 +Vaporettoは実験的にタグ推定(品詞推定や読み推定)に対応しています。 -品詞を学習するには、以下のように、データセットの各トークンに続けてスラッシュと品詞を追加します。 +タグを学習するには、以下のように、データセットの各トークンに続けてスラッシュとタグを追加します。 * フルアノテーションコーパスの場合 ``` - この/連体詞 人/名詞 は/助詞 火星/名詞 人/接尾辞 です/助動詞 + この/連体詞/コノ 人/名詞/ヒト は/助詞/ワ 火星/名詞/カセイ 人/接尾辞/ジン です/助動詞/デス ``` * 部分アノテーションコーパスの場合 @@ -201,9 +201,9 @@ Vaporettoは実験的に品詞推定に対応しています。 ヴ-ェ-ネ-ツ-ィ-ア/名詞|は/助詞|イ-タ-リ-ア/名詞|に/助詞|あ-り ま-す ``` -データセットに品詞が含まれる場合、 `train` コマンドは自動的にそれらを学習します。 +データセットにタグが含まれる場合、 `train` コマンドは自動的にそれらを学習します。 -推定時は、デフォルトでは品詞は推定されないため、必要に応じで `predict` コマンドに `--predict-tags` 引数を指定してください。 +推定時は、デフォルトではタグは推定されないため、必要に応じて `predict` コマンドに `--predict-tags` 引数を指定してください。 ## 各種トークナイザの速度比較 diff --git a/README.md b/README.md index cd43b44c..535b5619 100644 --- a/README.md +++ b/README.md @@ -186,15 +186,15 @@ Now `外国人参政権` is split into correct tokens. 9:交代 -7658 ``` -### POS tagging +### Tagging -Vaporetto experimentally supports POS tagging. +Vaporetto experimentally supports tagging (e.g., part-of-speech and pronunciation tags).
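A minimal sketch of how these tags surface through the new API (assuming the `Sentence` accessors introduced by this patch, `n_tags()` and the flat `tags()` slice that `evaluate/src/main.rs` slices below; the snippet is illustrative, not part of the diff):

```rust
use vaporetto::Sentence;

fn main() {
    // Two tag levels per token here: a POS tag and a pronunciation tag.
    let s = Sentence::from_tokenized("この/連体詞/コノ 人/名詞/ヒト").unwrap();
    assert_eq!(s.n_tags(), 2);
    // tags() is a flat slice holding n_tags() slots per character; a token's
    // tags occupy the slots of its last character, so chunking by n_tags()
    // walks the characters in order.
    for (i, slots) in s.tags().chunks(s.n_tags()).enumerate() {
        println!("char {}: {:?}", i, slots);
    }
}
```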
-To train tags, add a slash and tag name following each token in the dataset as follows: +To train tags, add slashes and tags following each token in the dataset as follows: * For fully annotated corpora ``` - この/連体詞 人/名詞 は/助詞 火星/名詞 人/接尾辞 です/助動詞 + この/連体詞/コノ 人/名詞/ヒト は/助詞/ワ 火星/名詞/カセイ 人/接尾辞/ジン です/助動詞/デス ``` * For partially annotated corpora diff --git a/evaluate/src/main.rs b/evaluate/src/main.rs index a0b266c8..dd1f32c9 100644 --- a/evaluate/src/main.rs +++ b/evaluate/src/main.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use std::str::FromStr; use clap::Parser; -use vaporetto::{BoundaryType, CharacterType, Model, Predictor, Sentence}; +use vaporetto::{CharacterBoundary, CharacterType, Model, Predictor, Sentence}; use vaporetto_rules::{ sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter}, string_filters::KyteaFullwidthFilter, @@ -107,19 +107,25 @@ fn main() -> Result<(), Box> { if line.is_empty() { continue; } - let mut s = Sentence::from_tokenized(line)?; + let mut s = Sentence::from_tokenized(&line)?; let ref_boundaries = s.boundaries().to_vec(); - let ref_tags = s.tags().to_vec(); + let mut ref_tags = vec![]; + for i in 0..=ref_boundaries.len() { + ref_tags.push(s.tags()[i * s.n_tags()..(i + 1) * s.n_tags()].to_vec()); + } if !args.no_norm { - let new_line = fullwidth_filter.filter(s.to_raw_string()); + let new_line = fullwidth_filter.filter(s.as_raw_text()); s = Sentence::from_raw(new_line)? }; - s = predictor.predict(s); - s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); - s = predictor.fill_tags(s); - let hyp_boundaries = s.boundaries().to_vec(); - let hyp_tags = s.tags().to_vec(); - results.push((ref_boundaries, ref_tags, hyp_boundaries, hyp_tags)); + predictor.predict(&mut s); + post_filters.iter().for_each(|filter| filter.filter(&mut s)); + s.fill_tags(); + let sys_boundaries = s.boundaries().to_vec(); + let mut sys_tags = vec![]; + for i in 0..=sys_boundaries.len() { + sys_tags.push(s.tags()[i * s.n_tags()..(i + 1) * s.n_tags()].to_vec()); + } + results.push((ref_boundaries, ref_tags, sys_boundaries, sys_tags)); } match args.metric { @@ -131,12 +137,12 @@ fn main() -> Result<(), Box> { for (rs_b, _, hs_b, _) in results { for (r, h) in rs_b.into_iter().zip(hs_b) { if r == h { - if h == BoundaryType::WordBoundary { + if h == CharacterBoundary::WordBoundary { n_tp += 1; } else { n_tn += 1; } - } else if h == BoundaryType::WordBoundary { + } else if h == CharacterBoundary::WordBoundary { n_fp += 1; } else { n_fn += 1; @@ -159,12 +165,13 @@ fn main() -> Result<(), Box> { let mut n_sys = 0; let mut n_ref = 0; let mut n_cor = 0; - for (rs_b, rs_t, hs_b, hs_t) in results { + for (refs_b, refs_t, syss_b, syss_t) in results { let mut matched = true; - for (((r_b, r_t), h_b), h_t) in rs_b.iter().zip(&rs_t).zip(&hs_b).zip(&hs_t) { - if r_b == h_b { - if *h_b == BoundaryType::WordBoundary { - if matched && r_t == h_t { + for (((r_b, r_t), s_b), s_t) in refs_b.iter().zip(&refs_t).zip(&syss_b).zip(&syss_t) + { + if r_b == s_b { + if *s_b == CharacterBoundary::WordBoundary { + if matched && r_t == s_t { n_cor += 1; } matched = true; @@ -172,7 +179,7 @@ fn main() -> Result<(), Box> { n_sys += 1; } } else { - if *h_b == BoundaryType::WordBoundary { + if *s_b == CharacterBoundary::WordBoundary { n_sys += 1; } else { n_ref += 1; @@ -180,7 +187,7 @@ fn main() -> Result<(), Box> { matched = false; } } - if matched && rs_t.last().unwrap() == hs_t.last().unwrap() { + if matched && refs_t.last().unwrap() == syss_t.last().unwrap() { n_cor += 1; } n_sys += 1; diff --git 
a/examples/embedded_device/Cargo.toml b/examples/embedded_device/Cargo.toml index d23415dd..473f377f 100644 --- a/examples/embedded_device/Cargo.toml +++ b/examples/embedded_device/Cargo.toml @@ -15,7 +15,7 @@ vaporetto_rules = { path = "../../vaporetto_rules" } alloc-cortex-m = "0.4.0" [build-dependencies] -vaporetto = { path = "../../vaporetto", default-features = false } +vaporetto = { path = "../../vaporetto", default-features = false, features = ["alloc"] } ruzstd = "0.2.4" # MIT [profile.release] diff --git a/examples/embedded_device/src/main.rs b/examples/embedded_device/src/main.rs index 8bf76c02..e1d94e90 100644 --- a/examples/embedded_device/src/main.rs +++ b/examples/embedded_device/src/main.rs @@ -8,7 +8,7 @@ extern crate alloc; use core::alloc::Layout; // alloc crate -use alloc::vec::Vec; +use alloc::string::String; // devices use alloc_cortex_m::CortexMHeap; @@ -17,11 +17,8 @@ use cortex_m_rt::entry; use cortex_m_semihosting::hprintln; // other crates -use vaporetto::{Predictor, Sentence, CharacterType}; -use vaporetto_rules::{ - sentence_filters::KyteaWsConstFilter, - SentenceFilter, -}; +use vaporetto::{CharacterType, Predictor, Sentence}; +use vaporetto_rules::{sentence_filters::KyteaWsConstFilter, SentenceFilter}; // panic behaviour use panic_halt as _; @@ -36,22 +33,23 @@ fn main() -> ! { unsafe { ALLOCATOR.init(cortex_m_rt::heap_start() as usize, HEAP_SIZE) } let predictor_data = include_bytes!(concat!(env!("OUT_DIR"), "/predictor.bin")); - let (predictor, _) = unsafe { Predictor::deserialize_from_slice_unchecked(predictor_data) }.unwrap(); + let (predictor, _) = + unsafe { Predictor::deserialize_from_slice_unchecked(predictor_data) }.unwrap(); - let docs = &[ - "🚤VaporettoはSTM32F303VCT6(FLASH:256KiB,RAM:40KiB)などの小さなデバイスでも動作します", - ]; + let docs = + &["🚤VaporettoはSTM32F303VCT6(FLASH:256KiB,RAM:40KiB)などの小さなデバイスでも動作します"]; let wsconst_d_filter = KyteaWsConstFilter::new(CharacterType::Digit); loop { for &text in docs { hprintln!("\x1b[32mINPUT:\x1b[m {:?}", text).unwrap(); - let s = Sentence::from_raw(text).unwrap(); - let s = predictor.predict(s); - let s = wsconst_d_filter.filter(s); - let v = s.to_tokenized_vec().unwrap().iter().map(|t| t.surface).collect::>(); - hprintln!("\x1b[31mOUTPUT:\x1b[m {:?}", v).unwrap(); + let mut s = Sentence::from_raw(text).unwrap(); + predictor.predict(&mut s); + wsconst_d_filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + hprintln!("\x1b[31mOUTPUT:\x1b[m {}", buf).unwrap(); } } } diff --git a/examples/wasm/src/lib.rs b/examples/wasm/src/lib.rs index 5db9d13d..11d34f28 100644 --- a/examples/wasm/src/lib.rs +++ b/examples/wasm/src/lib.rs @@ -1,7 +1,7 @@ use std::io::{Cursor, Read}; use js_sys::{Array, Object}; -use vaporetto::{BoundaryType, CharacterType, Model, Predictor, Sentence}; +use vaporetto::{CharacterBoundary, CharacterType, Model, Predictor, Sentence}; use vaporetto_rules::{ sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter}, string_filters::KyteaFullwidthFilter, @@ -61,22 +61,19 @@ impl Vaporetto { return result.into(); }; let norm = self.fullwidth_filter.filter(text); - let s_norm = if let Ok(s) = Sentence::from_raw(norm) { + let mut s_norm = if let Ok(s) = Sentence::from_raw(norm) { s } else { return result.into(); }; - let s_norm = self.predictor.predict(s_norm); + self.predictor.predict(&mut s_norm); s.boundaries_mut().clone_from_slice(s_norm.boundaries()); - let s = self - .post_filters + self.post_filters .iter() - .fold(s, |s, filter| filter.filter(s)); + 
.for_each(|filter| filter.filter(&mut s)); - if let Ok(tokens) = s.to_tokenized_vec() { - for token in tokens { - result.push(&JsValue::from_str(token.surface)); - } + for token in s.iter_tokens() { + result.push(&JsValue::from_str(token.surface())); } result.into() } @@ -85,41 +82,19 @@ impl Vaporetto { pub fn predict(&self, text: &str) -> Object { let result = Array::new(); let text = self.fullwidth_filter.filter(text); - let s = if let Ok(s) = Sentence::from_raw(text) { - s - } else { - return result.into(); - }; - let s = self.predictor.predict(s); - let s = self - .post_filters - .iter() - .fold(s, |s, filter| filter.filter(s)); - - for &b in s.boundaries() { - result.push(&JsValue::from_bool(b == BoundaryType::WordBoundary)); - } - result.into() - } - - #[wasm_bindgen] - pub fn predict_with_score(&self, text: &str) -> Object { - let result = Array::new(); - let text = self.fullwidth_filter.filter(text); - let s = if let Ok(s) = Sentence::from_raw(text) { + let mut s = if let Ok(s) = Sentence::from_raw(text) { s } else { return result.into(); }; - let s = self.predictor.predict_with_score(s); - let s = self - .post_filters + self.predictor.predict(&mut s); + self.post_filters .iter() - .fold(s, |s, filter| filter.filter(s)); + .for_each(|filter| filter.filter(&mut s)); for (&score, &b) in s.boundary_scores().iter().zip(s.boundaries()) { let boundary = Array::new(); - boundary.push(&(b == BoundaryType::WordBoundary).into()); + boundary.push(&(b == CharacterBoundary::WordBoundary).into()); boundary.push(&score.into()); result.push(&boundary); } diff --git a/examples/wasm/www/index.js b/examples/wasm/www/index.js index 47496fe1..e86708f5 100644 --- a/examples/wasm/www/index.js +++ b/examples/wasm/www/index.js @@ -13,7 +13,7 @@ vaporetto_bccwj_suw_small().then((Vaporetto) => { input_text.addEventListener("input", (e) => { const text = input_text.value; - const scores = vaporetto_suw.predict_with_score(text); + const scores = vaporetto_suw.predict(text); let i = -1; while (tokenized.firstChild) { tokenized.removeChild(tokenized.firstChild); diff --git a/predict/src/main.rs b/predict/src/main.rs index 94a73519..64c6076b 100644 --- a/predict/src/main.rs +++ b/predict/src/main.rs @@ -64,8 +64,11 @@ struct Args { } fn print_scores(s: &Sentence, out: &mut dyn Write) -> Result<(), Box> { - for (i, score) in s.boundary_scores().iter().enumerate() { - writeln!(out, "{}:{}{} {}", i, s.chars()[i], s.chars()[i + 1], score)?; + let mut chars_iter = s.as_raw_text().chars(); + let mut prev_c = chars_iter.next().unwrap(); + for (i, (c, score)) in chars_iter.zip(s.boundary_scores()).enumerate() { + writeln!(out, "{}:{}{} {}", i, prev_c, c, score)?; + prev_c = c; } writeln!(out)?; Ok(()) @@ -92,28 +95,23 @@ fn main() -> Result<(), Box> { eprintln!("Start tokenization"); let start = Instant::now(); - let stdout = io::stdout(); let mut out: Box = if args.buffered_out { - Box::new(BufWriter::new(stdout.lock())) + Box::new(BufWriter::new(io::stdout().lock())) } else { - Box::new(stdout.lock()) + Box::new(io::stdout().lock()) }; let mut buf = String::new(); - let mut s = Sentence::from_raw(" ")?; + let mut s = Sentence::default(); if args.no_norm { for line in io::stdin().lock().lines() { let line = line?; if s.update_raw(line).is_ok() { - s = if args.scores { - predictor.predict_with_score(s) - } else { - predictor.predict(s) - }; - s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); + predictor.predict(&mut s); + post_filters.iter().for_each(|filter| filter.filter(&mut s)); if args.predict_tags { - 
s = predictor.fill_tags(s); + s.fill_tags(); } - s.write_tokenized_string(&mut buf)?; + s.write_tokenized_text(&mut buf); writeln!(out, "{}", buf)?; if args.scores { print_scores(&s, &mut *out)?; @@ -123,24 +121,21 @@ fn main() -> Result<(), Box> { } } } else { - let mut s_orig = Sentence::from_raw(" ")?; + let mut s_orig = Sentence::default(); for line in io::stdin().lock().lines() { let line = line?; let line_preproc = pre_filter.filter(&line); if s.update_raw(line_preproc).is_ok() { - s = if args.scores { - predictor.predict_with_score(s) - } else { - predictor.predict(s) - }; - s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); + predictor.predict(&mut s); + post_filters.iter().for_each(|filter| filter.filter(&mut s)); if args.predict_tags { - s = predictor.fill_tags(s); + s.fill_tags(); } s_orig.update_raw(line)?; + s_orig.reset_tags(s.n_tags()); s_orig.boundaries_mut().copy_from_slice(s.boundaries()); s_orig.tags_mut().clone_from_slice(s.tags()); - s_orig.write_tokenized_string(&mut buf)?; + s_orig.write_tokenized_text(&mut buf); writeln!(out, "{}", buf)?; if args.scores { print_scores(&s, &mut *out)?; diff --git a/resources/docs.tok b/resources/docs.tok new file mode 100644 index 00000000..2fefa1d4 --- /dev/null +++ b/resources/docs.tok @@ -0,0 +1,2 @@ +まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ 火星/名詞/カセー 猫/名詞/ネコ だ/助動詞/ダ +まぁ/副詞/マー 良い/形容詞/ヨイ だろう/助動詞/ダロー diff --git a/resources/model.bin b/resources/model.bin new file mode 100644 index 0000000000000000000000000000000000000000..09ffc3e31a615637cab4731b6c2c6606f5214039 GIT binary patch literal 394 zcmY+8y-EW?6ouy}Gq@u71fr$Y5(G=ZhY(A_Vha&m38}1{84(g)unAa2(Lh`m-9j*k znwPn(n^Zo)7}q=OK(LwNoOADY?(Y6^+=|-mcqcxLjt)+u)><>%44aFLgHnBSOdre#U8)<)!FbvmQO0o5&MMkKJa=+dp~LX9Qe|J)H;j4?4M1n0 z-;v>p%FApXAv*#6;(SoIj0QzFEiQVlmV-G(KR_RJ0_yx<8ej_MsbZ@~H`~x=Yr`-Sm literal 0 HcmV?d00001 diff --git a/train/src/main.rs b/train/src/main.rs index 0cad6499..3927e6ee 100644 --- a/train/src/main.rs +++ b/train/src/main.rs @@ -46,7 +46,8 @@ struct Args { #[clap(long, default_value = "3")] typen: u8, - /// Dictionary words greater than this value will be grouped together + /// Dictionary words longer than this value will be grouped together, where the length is in + /// characters #[clap(long, default_value = "4")] dictn: u8, @@ -84,13 +85,14 @@ fn main() -> Result<(), Box> { eprint!("# of sentences: {}\r", i); stderr().flush()?; } - let s = Sentence::from_tokenized(line?)?; + let s = Sentence::from_tokenized(&line?)?; let s = if args.no_norm { s } else { - let new_line = fullwidth_filter.filter(s.to_raw_string()); + let new_line = fullwidth_filter.filter(s.as_raw_text()); let mut new_s = Sentence::from_raw(new_line)?; new_s.boundaries_mut().clone_from_slice(s.boundaries()); + new_s.reset_tags(s.n_tags()); new_s.tags_mut().clone_from_slice(s.tags()); new_s }; @@ -107,13 +109,14 @@ fn main() -> Result<(), Box> { eprint!("# of sentences: {}\r", i); stderr().flush()?; } - let s = Sentence::from_partial_annotation(line?)?; + let s = Sentence::from_partial_annotation(&line?)?; let s = if args.no_norm { s } else { - let new_line = fullwidth_filter.filter(s.to_raw_string()); + let new_line = fullwidth_filter.filter(s.as_raw_text()); let mut new_s = Sentence::from_raw(new_line)?; new_s.boundaries_mut().copy_from_slice(s.boundaries()); + new_s.reset_tags(s.n_tags()); new_s.tags_mut().clone_from_slice(s.tags()); new_s }; @@ -146,24 +149,16 @@ fn main() -> Result<(), Box> { eprintln!("Extracting into features..."); let mut trainer = Trainer::new( - args.charn, args.charw, args.typen, 
args.typew, dictionary, args.dictn, + args.charw, args.charn, args.typew, args.typen, dictionary, args.dictn, )?; for (i, s) in train_sents.iter().enumerate() { if i % 10000 == 0 { - eprint!( - "# of features: {}, # of tag features: {}\r", - trainer.n_features(), - trainer.n_tag_features() - ); + eprint!("# of features: {}\r", trainer.n_features(),); stderr().flush()?; } - trainer.push_sentence(s)?; + trainer.add_example(s); } - eprintln!( - "# of features: {}, # of tag features: {}", - trainer.n_features(), - trainer.n_tag_features() - ); + eprintln!("# of features: {}", trainer.n_features(),); eprintln!("Start training..."); let model = trainer.train(args.eps, args.cost, args.solver)?; diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml index 114d5772..dff801f9 100644 --- a/vaporetto/Cargo.toml +++ b/vaporetto/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vaporetto" -version = "0.4.0" +version = "0.5.0" edition = "2021" authors = ["Koichi Akabe "] description = "Vaporetto: a pointwise prediction based tokenizer" @@ -14,22 +14,24 @@ categories = ["text-processing", "no-std"] [dependencies] bincode = { version = "2.0.0-rc.1", default-features = false, features = ["alloc", "derive"] } # MIT daachorse = { version = "0.4.1", default-features = false } # MIT or Apache-2.0 +hashbrown = "0.12.1" # MIT or Apache-2.0 liblinear = { version = "1", optional = true } # MIT [features] -default = ["std", "cache-type-score", "fix-weight-length", "tag-prediction"] +default = ["std", "cache-type-score", "fix-weight-length", "tag-prediction", "charwise-pma"] # default: on -std = ["bincode/std"] -cache-type-score = [] -fix-weight-length = [] -tag-prediction = ["bincode/atomic"] +alloc = [] +std = ["alloc", "bincode/std"] +cache-type-score = ["alloc"] +fix-weight-length = ["alloc"] +tag-prediction = ["alloc", "bincode/atomic"] +charwise-pma = ["alloc"] kytea = ["std"] train = ["std", "liblinear"] portable-simd = ["fix-weight-length"] -charwise-daachorse = [] [package.metadata.docs.rs] all-features = true diff --git a/vaporetto/README.md b/vaporetto/README.md index 373de723..9bc53a87 100644 --- a/vaporetto/README.md +++ b/vaporetto/README.md @@ -6,21 +6,34 @@ Vaporetto is a fast and lightweight pointwise prediction based tokenizer. ```rust use std::fs::File; -use std::io::Read; use vaporetto::{Model, Predictor, Sentence}; -let mut f = File::open("model.bin").unwrap(); -let mut model_data = vec![]; -f.read_to_end(&mut model_data).unwrap(); -let (model, _) = Model::read_slice(&model_data).unwrap(); -let predictor = Predictor::new(model, false).unwrap(); - -let s = Sentence::from_raw("火星猫の生態").unwrap(); -let s = predictor.predict(s); - -println!("{:?}", s.to_tokenized_vec().unwrap()); -// ["火星", "猫", "の", "生態"] +let f = File::open("../resources/model.bin").unwrap(); +let model = Model::read(f).unwrap(); +let predictor = Predictor::new(model, true).unwrap(); + +let mut buf = String::new(); + +let mut s = Sentence::default(); + +s.update_raw("まぁ社長は火星猫だ").unwrap(); +predictor.predict(&mut s); +s.fill_tags(); +s.write_tokenized_text(&mut buf); +assert_eq!( + "まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ 火星/名詞/カセー 猫/名詞/ネコ だ/助動詞/ダ", + buf, +); + +s.update_raw("まぁ良いだろう").unwrap(); +predictor.predict(&mut s); +s.fill_tags(); +s.write_tokenized_text(&mut buf); +assert_eq!( + "まぁ/副詞/マー 良い/形容詞/ヨイ だろう/助動詞/ダロー", + buf, +); ``` ## Feature flags @@ -31,7 +44,6 @@ The following features are disabled by default: * `train` - Enables the trainer. 
* `portable-simd` - Uses the [portable SIMD API](https://github.com/rust-lang/portable-simd) instead of our SIMD-conscious data layout. (Nightly Rust is required.) -* `charwise-daachorse` - Uses the [Charwise Daachorse](https://docs.rs/daachorse/latest/daachorse/charwise/index.html) instead of the standard version for faster prediction, although it can make to load a model file slower. The following features are enabled by default: @@ -39,6 +51,7 @@ The following features are enabled by default: * `cache-type-score` - Enables caching type scores for faster processing. If disabled, type scores are calculated in a straightforward manner. * `fix-weight-length` - Uses fixed-size arrays for storing scores to facilitate optimization. If disabled, vectors are used instead. * `tag-prediction` - Enables tag prediction. +* `charwise-pma` - Uses the [Charwise Daachorse](https://docs.rs/daachorse/latest/daachorse/charwise/index.html) instead of the standard version for faster prediction, although it can make to load a model file slower. ## License diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 1d5eb981..b9bb1e8e 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -1,529 +1,526 @@ -use alloc::string::String; -use alloc::vec::Vec; +mod boundary_scorer; #[cfg(feature = "tag-prediction")] -use core::iter; +mod boundary_tag_scorer; -#[cfg(feature = "tag-prediction")] -use alloc::sync::Arc; +use core::cell::RefCell; +use core::ops::AddAssign; -use bincode::{ - de::{BorrowDecoder, Decoder}, - enc::Encoder, - error::{DecodeError, EncodeError}, - BorrowDecode, Decode, Encode, -}; +use alloc::collections::BTreeMap; +use alloc::string::String; +use alloc::vec::Vec; -#[cfg(feature = "charwise-daachorse")] -use daachorse::charwise::CharwiseDoubleArrayAhoCorasick; -#[cfg(not(feature = "charwise-daachorse"))] -use daachorse::DoubleArrayAhoCorasick; +use bincode::{BorrowDecode, Encode}; use crate::dict_model::DictModel; -use crate::errors::{Result, VaporettoError}; +use crate::errors::Result; use crate::ngram_model::NgramModel; use crate::sentence::Sentence; -use crate::utils::{AddWeight, MergableWeight, WeightMerger}; #[cfg(feature = "tag-prediction")] -use crate::sentence::{TagRangeScore, TagRangeScores, TagScores}; -#[cfg(feature = "tag-prediction")] -use crate::utils; - -#[cfg(feature = "portable-simd")] -use core::simd::i32x8; +use crate::ngram_model::TagNgramModel; -pub const SIMD_SIZE: usize = 8; -#[cfg(feature = "portable-simd")] -type I32Vec = i32x8; +use boundary_scorer::CharScorerBoundary; -#[derive(Clone, Decode, Encode)] -struct PositionalWeight { - pub weight: W, - pub offset: i16, -} - -type NaivePositionalWeight = PositionalWeight>; +#[cfg(feature = "tag-prediction")] +use boundary_tag_scorer::CharScorerBoundaryTag; -impl NaivePositionalWeight { - fn new(offset: i16, weight: Vec) -> Self { - Self { offset, weight } - } +#[derive(Default)] +struct CharWeightMerger { + map: BTreeMap>, } -impl MergableWeight for NaivePositionalWeight { - fn from_two_weights(weight1: &Self, weight2: &Self, n_classes: usize) -> Self { - debug_assert!(n_classes != 0); - let (weight1, weight2) = if weight1.offset > weight2.offset { - (weight2, weight1) +impl CharWeightMerger +where + for<'a> W: AddAssign<&'a W>, +{ + pub fn add(&mut self, ngram: S, weight: W) + where + S: Into + AsRef, + { + if let Some(data) = self.map.get_mut(ngram.as_ref()) { + let (prev_weight, _) = &mut *data.borrow_mut(); + *prev_weight += &weight; } else { - (weight1, weight2) - }; - 
let shift = (weight2.offset - weight1.offset) as usize * n_classes; - let mut weight = vec![0; weight1.weight.len().max(shift + weight2.weight.len())]; - weight[..weight1.weight.len()].copy_from_slice(&weight1.weight); - for (r, w2) in weight[shift..].iter_mut().zip(&weight2.weight) { - *r += w2; + self.map.insert(ngram.into(), RefCell::new((weight, false))); } - Self { - offset: weight1.offset, - weight, - } - } -} - -#[derive(Clone)] -enum WeightVector { - Variable(Vec), - - #[cfg(all(feature = "fix-weight-length", not(feature = "portable-simd")))] - Fixed([i32; SIMD_SIZE]), - #[cfg(all(feature = "fix-weight-length", feature = "portable-simd"))] - Fixed(I32Vec), -} - -impl Decode for WeightVector { - fn decode(decoder: &mut D) -> Result { - let v: Vec = Decode::decode(decoder)?; - #[cfg(feature = "fix-weight-length")] - let result = if v.len() <= SIMD_SIZE { - let mut arr = [0; SIMD_SIZE]; - arr[..v.len()].copy_from_slice(&v); - - #[cfg(feature = "portable-simd")] - let arr = I32Vec::from_array(arr); - - Self::Fixed(arr) - } else { - Self::Variable(v) - }; - - #[cfg(not(feature = "fix-weight-length"))] - let result = Self::Variable(v); - - Ok(result) } -} -impl Encode for WeightVector { - fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { - match self { - Self::Variable(v) => { - Encode::encode(v, encoder)?; + #[must_use] + pub fn merge(self) -> Vec<(String, W)> { + let mut stack = vec![]; + for (ngram, data) in &self.map { + if data.borrow().1 { + continue; } - - #[cfg(feature = "fix-weight-length")] - Self::Fixed(v) => { - #[cfg(feature = "portable-simd")] - let v = &v.as_array(); - - let mut len = v.len(); - for (i, w) in v.iter().enumerate().rev() { - if *w != 0 { + stack.push(data); + for (j, _) in ngram.char_indices().skip(1) { + if let Some(data) = self.map.get(&ngram[j..]) { + stack.push(data); + if data.borrow().1 { break; } - len = i; } - Encode::encode(&v[..len].to_vec(), encoder)?; } - } - Ok(()) - } -} - -impl WeightVector { - pub fn new(weight: Vec) -> Self { - #[cfg(feature = "fix-weight-length")] - let v = if weight.len() <= SIMD_SIZE { - let mut arr = [0i32; SIMD_SIZE]; - arr[..weight.len()].copy_from_slice(weight.as_slice()); - - #[cfg(feature = "portable-simd")] - let arr = I32Vec::from_array(arr); - - Self::Fixed(arr) - } else { - Self::Variable(weight) - }; - - #[cfg(not(feature = "fix-weight-length"))] - let v = Self::Variable(weight); - - v - } - - fn add_weight(&self, ys: &mut [i32], offset: usize) { - match self { - WeightVector::Variable(weight) => { - weight.add_weight(ys, offset); - } - - #[cfg(feature = "fix-weight-length")] - WeightVector::Fixed(weight) => { - let ys_slice = &mut ys[offset..offset + SIMD_SIZE]; - #[cfg(feature = "portable-simd")] - { - let mut target = I32Vec::from_slice(ys_slice); - target += weight; - ys_slice.copy_from_slice(target.as_array()); - } - #[cfg(not(feature = "portable-simd"))] - for (y, w) in ys_slice.iter_mut().zip(weight) { - *y += w; - } + let mut data_from = stack.pop().unwrap(); + data_from.borrow_mut().1 = true; + while let Some(data_to) = stack.pop() { + let data_to_ref = &mut data_to.borrow_mut(); + data_to_ref.1 = true; + data_to_ref.0 += &data_from.borrow().0; + data_from = data_to; } } + self.map + .into_iter() + .map(|(ngram, weight)| (ngram, weight.into_inner().0)) + .collect() } } -#[cfg(feature = "tag-prediction")] -#[derive(Decode, Encode)] -pub struct WeightSet { - boundary: Option>, - tag_left: Option>>, - tag_right: Option>>, - tag_self: Option, -} +/// WARNING: Decoding is inherently unsafe. 
Do not publish this struct outside this +/// crate. +#[derive(BorrowDecode, Encode)] +pub enum CharScorer { + Boundary(CharScorerBoundary), + + #[cfg(feature = "tag-prediction")] + BoundaryTag(CharScorerBoundaryTag), +} + +impl CharScorer { + pub fn new( + ngram_model: NgramModel<String>, + dict_model: DictModel, + window_size: u8, + #[cfg(feature = "tag-prediction")] tag_ngram_model: Vec<TagNgramModel<String>>, + ) -> Result<Option<Self>> { + if ngram_model.0.is_empty() && dict_model.0.is_empty() || window_size == 0 { + return Ok(None); + } + + #[cfg(feature = "tag-prediction")] + if tag_ngram_model.is_empty() { + Ok(Some(Self::Boundary(CharScorerBoundary::new( + ngram_model, + dict_model, + window_size, + )?))) + } else { + Ok(Some(Self::BoundaryTag(CharScorerBoundaryTag::new( + ngram_model, + dict_model, + window_size, + tag_ngram_model, + )?))) + } + + #[cfg(not(feature = "tag-prediction"))] + Ok(Some(Self::Boundary(CharScorerBoundary::new( + ngram_model, + dict_model, + window_size, + )?))) + } + + #[inline] + pub fn add_scores<'a, 'b>(&self, sentence: &mut Sentence<'a, 'b>) { + match self { + Self::Boundary(scorer) => scorer.add_scores(sentence), + + #[cfg(feature = "tag-prediction")] + Self::BoundaryTag(scorer) => scorer.add_scores(sentence), + } + } + + /// # Safety + /// + /// `token_id` must be smaller than `scorer.tag_weight.len()`. + /// `pos` must be smaller than `sentence.char_pma_states.len()`.
+ #[cfg(feature = "tag-prediction")] + #[inline] + pub unsafe fn add_tag_scores( + &self, + token_id: u32, + pos: usize, + sentence: &Sentence, + scores: &mut [i32], + ) { + match self { + Self::Boundary(_) => panic!("unsupported"), + Self::BoundaryTag(scorer) => scorer.add_tag_scores(token_id, pos, sentence, scores), } } } -pub struct CharScorer { - #[cfg(feature = "charwise-daachorse")] - pma: CharwiseDoubleArrayAhoCorasick, - #[cfg(not(feature = "charwise-daachorse"))] - pma: DoubleArrayAhoCorasick, - weights: Vec>, -} - -impl CharScorer { - pub fn new(model: NgramModel, window_size: u8, dict: DictModel) -> Result { - let mut weight_merger = WeightMerger::new(1); - - for d in model.data { - let weight = PositionalWeight::new(-i16::from(window_size) - 1, d.weights); - weight_merger.add(&d.ngram, weight); - } - for d in dict.dict { - let word_len = d.word.chars().count(); - let word_len = i16::try_from(word_len).map_err(|_| { - VaporettoError::invalid_model( - "words must be shorter than or equal to 32767 characters", - ) - })?; - let weight = PositionalWeight::new(-word_len - 1, d.weights); - weight_merger.add(&d.word, weight); - } - - let mut ngrams = vec![]; - let mut weights = vec![]; - for (ngram, data) in weight_merger.merge() { - ngrams.push(ngram); - let PositionalWeight { offset, weight } = data; - weights.push(PositionalWeight { - offset, - weight: WeightVector::new(weight), - }); - } - #[cfg(feature = "charwise-daachorse")] - let pma = CharwiseDoubleArrayAhoCorasick::new(ngrams) - .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?; - #[cfg(not(feature = "charwise-daachorse"))] - let pma = DoubleArrayAhoCorasick::new(ngrams) - .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?; - Ok(Self { pma, weights }) +#[cfg(test)] +mod tests { + use super::*; + + use crate::dict_model::WordWeightRecord; + use crate::ngram_model::NgramData; + use crate::predictor::PositionalWeight; + + use crate::predictor::WEIGHT_FIXED_LEN; + + #[cfg(feature = "tag-prediction")] + use crate::ngram_model::{TagNgramData, TagWeight}; + + #[rustfmt::skip] + #[test] + fn test_weight_merger() { + let mut merger = CharWeightMerger::default(); + merger.add("東京都", PositionalWeight::new(-3, vec![1, 2, 3, 4])); + merger.add("京都", PositionalWeight::new(-3, vec![2, 4, 6, 8, 10])); + merger.add("京都", PositionalWeight::new(-2, vec![3, 6, 9])); + merger.add("大阪", PositionalWeight::new(-2, vec![4, 8, 12])); + assert_eq!( + vec![ + ("京都".into(), PositionalWeight::new(-3, vec![2, 7, 12, 17, 10])), + ("大阪".into(), PositionalWeight::new(-2, vec![4, 8, 12])), + ("東京都".into(), PositionalWeight::new(-3, vec![3, 9, 15, 21, 10])), + ], + merger.merge(), + ); } - #[allow(clippy::cast_possible_wrap)] - pub fn add_scores(&self, sentence: &Sentence, padding: u8, ys: &mut [i32]) { - // If the following assertion fails, Vaporetto has a bug. - assert_eq!(sentence.str_to_char_pos.len(), sentence.text.len() + 1); - - for m in self.pma.find_overlapping_no_suffix_iter(&sentence.text) { - // This was checked outside of the iteration. - let m_end = unsafe { *sentence.str_to_char_pos.get_unchecked(m.end()) }; - // Both the weights and the PMA always have the same number of items. - // Therefore, the following code is safe. 
- let pos_weights = unsafe { self.weights.get_unchecked(m.value()) }; - - let offset = isize::from(padding) + m_end as isize + isize::from(pos_weights.offset); - pos_weights.weight.add_weight(ys, offset as usize); - } + #[test] + fn test_add_scores_1() { + // input: 我 ら は 全 世 界 の 国 民 + // n-grams: + // 我ら: 3 4 5 + // 全世界: 6 7 8 9 + // 国民: 10 11 12 + // 世界: 15 16 17 18 19 + // 界: 20 21 22 23 24 25 + // dict: + // 全世界: 26 27 28 29 + // 世界: 30 31 32 + // 世: 33 34 + let scorer = CharScorerBoundary::new( + NgramModel(vec![ + NgramData { + ngram: "我ら".into(), + weights: vec![1, 2, 3, 4, 5], + }, + NgramData { + ngram: "全世界".into(), + weights: vec![6, 7, 8, 9], + }, + NgramData { + ngram: "国民".into(), + weights: vec![10, 11, 12, 13, 14], + }, + NgramData { + ngram: "世界".into(), + weights: vec![15, 16, 17, 18, 19], + }, + NgramData { + ngram: "界".into(), + weights: vec![20, 21, 22, 23, 24, 25], + }, + ]), + DictModel(vec![ + WordWeightRecord { + word: "全世界".into(), + weights: vec![26, 27, 28, 29], + comment: "".into(), + }, + WordWeightRecord { + word: "世界".into(), + weights: vec![30, 31, 32], + comment: "".into(), + }, + WordWeightRecord { + word: "世".into(), + weights: vec![33, 34], + comment: "".into(), + }, + ]), + 3, + ) + .unwrap(); + let mut sentence = Sentence::from_raw("我らは全世界の国民").unwrap(); + sentence.score_padding = WEIGHT_FIXED_LEN - 1; + sentence.boundary_scores.clear(); + sentence + .boundary_scores + .resize(sentence.score_padding * 2 + sentence.len() - 1, 1); + scorer.add_scores(&mut sentence); + assert_eq!( + &[4, 5, 73, 135, 141, 122, 55, 38], + sentence.boundary_scores(), + ); } -} -impl<'de> BorrowDecode<'de> for CharScorer { - /// WARNING: This function is inherently unsafe. Do not publish this function outside this - /// crate. - fn borrow_decode>(decoder: &mut D) -> Result { - let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?; - #[cfg(feature = "charwise-daachorse")] - let (pma, _) = - unsafe { CharwiseDoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) }; - #[cfg(not(feature = "charwise-daachorse"))] - let (pma, _) = - unsafe { DoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) }; - Ok(Self { - pma, - weights: Decode::decode(decoder)?, - }) + #[test] + fn test_add_scores_2() { + // input: 我 ら は 全 世 界 の 国 民 + // n-grams: + // 我ら: 2 3 + // 全世界: 4 5 + // 国民: 6 7 + // 世界: 9 10 11 + // 界: 12 13 14 15 + // dict: + // 全世界: 16 17 18 19 + // 世界: 20 21 22 + // 世: 23 24 + let scorer = CharScorerBoundary::new( + NgramModel(vec![ + NgramData { + ngram: "我ら".into(), + weights: vec![1, 2, 3], + }, + NgramData { + ngram: "全世界".into(), + weights: vec![4, 5], + }, + NgramData { + ngram: "国民".into(), + weights: vec![6, 7, 8], + }, + NgramData { + ngram: "世界".into(), + weights: vec![9, 10, 11], + }, + NgramData { + ngram: "界".into(), + weights: vec![12, 13, 14, 15], + }, + ]), + DictModel(vec![ + WordWeightRecord { + word: "全世界".into(), + weights: vec![16, 17, 18, 19], + comment: "".into(), + }, + WordWeightRecord { + word: "世界".into(), + weights: vec![20, 21, 22], + comment: "".into(), + }, + WordWeightRecord { + word: "世".into(), + weights: vec![23, 24], + comment: "".into(), + }, + ]), + 2, + ) + .unwrap(); + let mut sentence = Sentence::from_raw("我らは全世界の国民").unwrap(); + sentence.score_padding = WEIGHT_FIXED_LEN - 1; + sentence.boundary_scores.clear(); + sentence + .boundary_scores + .resize(sentence.score_padding * 2 + sentence.len() - 1, 2); + scorer.add_scores(&mut sentence); + assert_eq!(&[4, 5, 18, 87, 93, 68, 23, 9], 
sentence.boundary_scores(),); } -} -impl Encode for CharScorer { - fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { - let pma_data = self.pma.serialize_to_vec(); - Encode::encode(&pma_data, encoder)?; - Encode::encode(&self.weights, encoder)?; - Ok(()) + #[test] + fn test_add_scores_3() { + // input: 我 ら は 全 世 界 の 国 民 + // n-grams: + // 我ら: 3 4 5 + // 全世界: 6 7 8 9 + // 国民: 10 11 12 + // 世界: 15 16 17 18 19 + // 界: 20 21 22 23 24 25 + // dict: + // 全世界: 26 27 28 29 + // 世界: 30 31 32 + // 世: 33 34 + // 世界の国民: 35 36 37 38 39 + // は全世界: 41 42 43 44 45 + let scorer = CharScorerBoundary::new( + NgramModel(vec![ + NgramData { + ngram: "我ら".into(), + weights: vec![1, 2, 3, 4, 5], + }, + NgramData { + ngram: "全世界".into(), + weights: vec![6, 7, 8, 9], + }, + NgramData { + ngram: "国民".into(), + weights: vec![10, 11, 12, 13, 14], + }, + NgramData { + ngram: "世界".into(), + weights: vec![15, 16, 17, 18, 19], + }, + NgramData { + ngram: "界".into(), + weights: vec![20, 21, 22, 23, 24, 25], + }, + ]), + DictModel(vec![ + WordWeightRecord { + word: "全世界".into(), + weights: vec![26, 27, 28, 29], + comment: "".into(), + }, + WordWeightRecord { + word: "世界".into(), + weights: vec![30, 31, 32], + comment: "".into(), + }, + WordWeightRecord { + word: "世".into(), + weights: vec![33, 34], + comment: "".into(), + }, + WordWeightRecord { + word: "世界の国民".into(), + weights: vec![35, 36, 37, 38, 39, 40], + comment: "".into(), + }, + WordWeightRecord { + word: "は全世界".into(), + weights: vec![41, 42, 43, 44, 45], + comment: "".into(), + }, + ]), + 3, + ) + .unwrap(); + let mut sentence = Sentence::from_raw("我らは全世界の国民").unwrap(); + sentence.score_padding = WEIGHT_FIXED_LEN - 1; + sentence.boundary_scores.clear(); + sentence + .boundary_scores + .resize(sentence.score_padding * 2 + sentence.len() - 1, 3); + scorer.add_scores(&mut sentence); + assert_eq!( + &[6, 48, 117, 215, 223, 206, 95, 79], + sentence.boundary_scores(), + ); } -} - -#[cfg(feature = "tag-prediction")] -pub struct CharScorerWithTags { - #[cfg(feature = "charwise-daachorse")] - pma: CharwiseDoubleArrayAhoCorasick, - #[cfg(not(feature = "charwise-daachorse"))] - pma: DoubleArrayAhoCorasick, - weights: Vec>, - n_tags: usize, -} -#[cfg(feature = "tag-prediction")] -impl CharScorerWithTags { - pub fn new( - model: NgramModel, - window_size: u8, - dict: DictModel, - n_tags: usize, - tag_left_model: NgramModel, - tag_right_model: NgramModel, - tag_self_model: NgramModel, - ) -> Result { - let mut weight_merger = WeightMerger::new(n_tags); - - for d in model.data { - let weight = WeightSet::boundary_weight(-i16::from(window_size), d.weights); - weight_merger.add(&d.ngram, weight); - } - for d in dict.dict { - let word_len = d.word.chars().count(); - let word_len = i16::try_from(word_len).map_err(|_| { - VaporettoError::invalid_model( - "words must be shorter than or equal to 32767 characters", - ) - })?; - let weight = WeightSet::boundary_weight(-word_len, d.weights); - weight_merger.add(&d.word, weight); - } - for d in tag_left_model.data { - let ngram_len = i16::try_from(d.ngram.chars().count()).map_err(|_| { - VaporettoError::invalid_model( - "character n-grams must be shorter than or equal to 32767 characters", - ) - })?; - let weight = WeightSet::tag_left_weight(-ngram_len + 1, d.weights); - weight_merger.add(&d.ngram, weight); - } - for d in tag_right_model.data { - let weight = WeightSet::tag_right_weight(-i16::from(window_size) - 1, d.weights); - weight_merger.add(&d.ngram, weight); - } - for d in tag_self_model.data { - let ngram_len = 
i16::try_from(d.ngram.chars().count()).map_err(|_| { - VaporettoError::invalid_model( - "character n-grams must be shorter than or equal to 32767 characters", - ) - })?; - let weight = WeightSet::tag_self_weight(-ngram_len, d.weights); - weight_merger.add(&d.ngram, weight); + #[cfg(feature = "tag-prediction")] + #[test] + fn test_add_scores_with_tags() { + // input: こ の 人 は 火 星 人 だ + // n-grams: + // この人: 2 3 4 + // 人だ: 5 6 7 + // dict: + // 人: 10 11 10 11 + // 火星: 12 13 14 + let scorer = CharScorerBoundaryTag::new( + NgramModel(vec![ + NgramData { + ngram: "この人".into(), + weights: vec![1, 2, 3, 4], + }, + NgramData { + ngram: "人だ".into(), + weights: vec![5, 6, 7, 8, 9], + }, + ]), + DictModel(vec![ + WordWeightRecord { + word: "人".into(), + weights: vec![10, 11], + comment: "".into(), + }, + WordWeightRecord { + word: "火星".into(), + weights: vec![12, 13, 14], + comment: "".into(), + }, + ]), + 3, + vec![ + TagNgramModel(vec![ + TagNgramData { + ngram: "の人".into(), + weights: vec![ + TagWeight { + rel_position: 0, + weights: vec![15, 16, 17], + }, + TagWeight { + rel_position: 1, + weights: vec![18, 19, 20], + }, + ], + }, + TagNgramData { + ngram: "人は".into(), + weights: vec![ + TagWeight { + rel_position: 1, + weights: vec![21, 22, 23], + }, + TagWeight { + rel_position: 3, + weights: vec![24, 25, 26], + }, + ], + }, + TagNgramData { + ngram: "火星人".into(), + weights: vec![TagWeight { + rel_position: 0, + weights: vec![27, 28, 29], + }], + }, + ]), + TagNgramModel(vec![]), + TagNgramModel(vec![ + TagNgramData { + ngram: "人は".into(), + weights: vec![ + TagWeight { + rel_position: 0, + weights: vec![27, 28], + }, + TagWeight { + rel_position: 3, + weights: vec![29, 30], + }, + ], + }, + TagNgramData { + ngram: "は火星人".into(), + weights: vec![TagWeight { + rel_position: 3, + weights: vec![31, 32], + }], + }, + ]), + ], + ) + .unwrap(); + let mut sentence = Sentence::from_raw("この人は火星人だ").unwrap(); + sentence.score_padding = WEIGHT_FIXED_LEN - 1; + sentence.boundary_scores.clear(); + sentence + .boundary_scores + .resize(sentence.score_padding * 2 + sentence.len() - 1, 1); + scorer.add_scores(&mut sentence); + assert_eq!(&[3, 14, 16, 13, 19, 31, 19], sentence.boundary_scores()); + + let mut tag_scores = [1; 8]; + unsafe { + scorer.add_tag_scores(0, 2, &sentence, &mut tag_scores); } + assert_eq!(&[37, 39, 41, 1, 1, 1, 1, 1], &tag_scores); - let mut ngrams = vec![]; - let mut weights = vec![]; - for (ngram, data) in weight_merger.merge() { - ngrams.push(ngram); - let WeightSet { - boundary, - tag_left, - tag_right, - tag_self, - } = data; - weights.push(WeightSet { - boundary: boundary.map(|PositionalWeight { offset, weight }| PositionalWeight { - offset, - weight: WeightVector::new(weight), - }), - tag_left, - tag_right, - tag_self, - }); + let mut tag_scores = [1; 8]; + unsafe { + scorer.add_tag_scores(0, 6, &sentence, &mut tag_scores); } - #[cfg(feature = "charwise-daachorse")] - let pma = CharwiseDoubleArrayAhoCorasick::new(ngrams) - .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?; - #[cfg(not(feature = "charwise-daachorse"))] - let pma = DoubleArrayAhoCorasick::new(ngrams) - .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?; - Ok(Self { - pma, - weights, - n_tags, - }) - } - - #[allow(clippy::cast_possible_wrap)] - pub fn add_scores( - &self, - sentence: &Sentence, - padding: u8, - ys: &mut [i32], - tag_ys: &mut TagScores, - ) { - #[cfg(not(feature = "charwise-daachorse"))] - let no_suffix_iter = 
self.pma.find_overlapping_no_suffix_iter_from_iter( - iter::once(0) - .chain(sentence.text.as_bytes().iter().cloned()) - .chain(iter::once(0)), - ); - // Since `sentence.text` is a valid UTF-8 string ensured by type `String`, - // the following code is safe. - #[cfg(feature = "charwise-daachorse")] - let no_suffix_iter = unsafe { - self.pma.find_overlapping_no_suffix_iter_from_iter( - iter::once(0) - .chain(sentence.text.as_bytes().iter().cloned()) - .chain(iter::once(0)), - ) - }; - for m in no_suffix_iter { - let m_end = sentence - .str_to_char_pos - .get(m.end() - 1) - .copied() - .unwrap_or(sentence.chars.len() + 1); + assert_eq!(&[28, 29, 30, 1, 1, 1, 1, 1], &tag_scores); - // Both the weights and the PMA always have the same number of items. - // Therefore, the following code is safe. - let weight_set = unsafe { self.weights.get_unchecked(m.value()) }; - - if let Some(pos_weights) = weight_set.boundary.as_ref() { - let offset = - isize::from(padding) + m_end as isize + isize::from(pos_weights.offset) - 1; - pos_weights.weight.add_weight(ys, offset as usize); - } - if let Some(pos_weights) = weight_set.tag_left.as_ref() { - let offset = - (m_end as isize + isize::from(pos_weights.offset)) * self.n_tags as isize; - pos_weights - .weight - .add_weight_signed(&mut tag_ys.left_scores, offset); - } - if let Some(pos_weights) = weight_set.tag_right.as_ref() { - let offset = - (m_end as isize + isize::from(pos_weights.offset)) * self.n_tags as isize; - pos_weights - .weight - .add_weight_signed(&mut tag_ys.right_scores, offset); - } - if let Some(weight) = weight_set.tag_self.as_ref() { - tag_ys.self_scores[m_end - 1].replace(Arc::clone(weight)); - } + let mut tag_scores = [1; 8]; + unsafe { + scorer.add_tag_scores(2, 3, &sentence, &mut tag_scores); } - } -} - -#[cfg(feature = "tag-prediction")] -impl<'de> BorrowDecode<'de> for CharScorerWithTags { - /// WARNING: This function is inherently unsafe. Do not publish this function outside this - /// crate. 
- fn borrow_decode>(decoder: &mut D) -> Result { - let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?; - #[cfg(feature = "charwise-daachorse")] - let (pma, _) = - unsafe { CharwiseDoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) }; - #[cfg(not(feature = "charwise-daachorse"))] - let (pma, _) = - unsafe { DoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) }; - Ok(Self { - pma, - weights: Decode::decode(decoder)?, - n_tags: Decode::decode(decoder)?, - }) - } -} - -#[cfg(feature = "tag-prediction")] -impl Encode for CharScorerWithTags { - fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { - let pma_data = self.pma.serialize_to_vec(); - Encode::encode(&pma_data, encoder)?; - Encode::encode(&self.weights, encoder)?; - Encode::encode(&self.n_tags, encoder)?; - Ok(()) + assert_eq!(&[59, 61, 1, 1, 1, 1, 1, 1], &tag_scores); } } diff --git a/vaporetto/src/char_scorer/boundary_scorer.rs b/vaporetto/src/char_scorer/boundary_scorer.rs new file mode 100644 index 00000000..a2e9bd82 --- /dev/null +++ b/vaporetto/src/char_scorer/boundary_scorer.rs @@ -0,0 +1,113 @@ +use alloc::string::String; +use alloc::vec::Vec; + +use bincode::{ + de::BorrowDecoder, + enc::Encoder, + error::{DecodeError, EncodeError}, + BorrowDecode, Decode, Encode, +}; +#[cfg(feature = "charwise-pma")] +use daachorse::charwise::CharwiseDoubleArrayAhoCorasick; +#[cfg(not(feature = "charwise-pma"))] +use daachorse::DoubleArrayAhoCorasick; + +use crate::char_scorer::CharWeightMerger; +use crate::dict_model::DictModel; +use crate::errors::{Result, VaporettoError}; +use crate::ngram_model::NgramModel; +use crate::predictor::{PositionalWeight, WeightVector}; +use crate::sentence::Sentence; + +pub struct CharScorerBoundary { + #[cfg(not(feature = "charwise-pma"))] + pma: DoubleArrayAhoCorasick, + #[cfg(feature = "charwise-pma")] + pma: CharwiseDoubleArrayAhoCorasick, + weights: Vec>, +} + +impl<'de> BorrowDecode<'de> for CharScorerBoundary { + /// WARNING: This function is inherently unsafe. Do not publish this function outside this + /// crate. 
+ fn borrow_decode>(decoder: &mut D) -> Result { + let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?; + #[cfg(not(feature = "charwise-pma"))] + let (pma, _) = + unsafe { DoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) }; + #[cfg(feature = "charwise-pma")] + let (pma, _) = + unsafe { CharwiseDoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) }; + Ok(Self { + pma, + weights: Decode::decode(decoder)?, + }) + } +} + +impl Encode for CharScorerBoundary { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + let pma_data = self.pma.serialize_to_vec(); + Encode::encode(&pma_data, encoder)?; + Encode::encode(&self.weights, encoder)?; + Ok(()) + } +} + +impl CharScorerBoundary { + pub fn new( + ngram_model: NgramModel, + dict_model: DictModel, + window_size: u8, + ) -> Result { + let mut merger = CharWeightMerger::default(); + for d in ngram_model.0 { + let weight = PositionalWeight::new(-i16::from(window_size), d.weights); + merger.add(d.ngram, weight); + } + for d in dict_model.0 { + let word_len = d.word.chars().count(); + let word_len = i16::try_from(word_len).map_err(|_| { + VaporettoError::invalid_model( + "words must be shorter than or equal to 32767 characters", + ) + })?; + let weight = PositionalWeight::new(-word_len, d.weights); + merger.add(d.word, weight); + } + let mut ngrams = vec![]; + let mut weights = vec![]; + for (ngram, weight) in merger.merge() { + ngrams.push(ngram); + weights.push(weight.into()); + } + #[cfg(not(feature = "charwise-pma"))] + let pma = DoubleArrayAhoCorasick::new(ngrams) + .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?; + #[cfg(feature = "charwise-pma")] + let pma = CharwiseDoubleArrayAhoCorasick::new(ngrams) + .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?; + Ok(Self { pma, weights }) + } + + #[allow(clippy::cast_possible_wrap)] + #[inline(always)] + pub fn add_scores<'a, 'b>(&self, sentence: &mut Sentence<'a, 'b>) { + #[cfg(not(feature = "charwise-pma"))] + let it = self + .pma + .find_overlapping_no_suffix_iter(sentence.text.as_bytes()); + #[cfg(feature = "charwise-pma")] + let it = self.pma.find_overlapping_no_suffix_iter(&sentence.text); + for m in it { + debug_assert!(m.end() != 0 && sentence.text.is_char_boundary(m.end())); + let end = unsafe { sentence.str_to_char_pos(m.end()) }; + debug_assert!(m.value() < self.weights.len()); + let weight = unsafe { self.weights.get_unchecked(m.value()) }; + weight.add_score( + (end + sentence.score_padding - 1) as isize, + &mut sentence.boundary_scores, + ); + } + } +} diff --git a/vaporetto/src/char_scorer/boundary_tag_scorer.rs b/vaporetto/src/char_scorer/boundary_tag_scorer.rs new file mode 100644 index 00000000..709447e8 --- /dev/null +++ b/vaporetto/src/char_scorer/boundary_tag_scorer.rs @@ -0,0 +1,185 @@ +use alloc::string::String; +use alloc::vec::Vec; + +use bincode::{ + de::BorrowDecoder, + enc::Encoder, + error::{DecodeError, EncodeError}, + BorrowDecode, Decode, Encode, +}; +#[cfg(feature = "charwise-pma")] +use daachorse::charwise::CharwiseDoubleArrayAhoCorasick; +#[cfg(not(feature = "charwise-pma"))] +use daachorse::DoubleArrayAhoCorasick; +use hashbrown::HashMap; + +use crate::char_scorer::CharWeightMerger; +use crate::dict_model::DictModel; +use crate::errors::{Result, VaporettoError}; +use crate::ngram_model::{NgramModel, TagNgramModel}; +use crate::predictor::{PositionalWeight, PositionalWeightWithTag, WeightVector}; +use crate::sentence::Sentence; +use 
+
+pub struct CharScorerBoundaryTag {
+    #[cfg(not(feature = "charwise-pma"))]
+    pma: DoubleArrayAhoCorasick,
+    #[cfg(feature = "charwise-pma")]
+    pma: CharwiseDoubleArrayAhoCorasick,
+    weights: Vec<Option<PositionalWeight<WeightVector>>>,
+    tag_weight: Vec<Vec<HashMap<u32, WeightVector, SplitMix64Builder>>>,
+}
+
+impl<'de> BorrowDecode<'de> for CharScorerBoundaryTag {
+    /// WARNING: This function is inherently unsafe. Do not publish this function outside this
+    /// crate.
+    fn borrow_decode<D: BorrowDecoder<'de>>(decoder: &mut D) -> Result<Self, DecodeError> {
+        let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?;
+        #[cfg(not(feature = "charwise-pma"))]
+        let (pma, _) =
+            unsafe { DoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) };
+        #[cfg(feature = "charwise-pma")]
+        let (pma, _) =
+            unsafe { CharwiseDoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) };
+        // Decode the fields in the same order as they are written by encode() below:
+        // `weights` first, then `tag_weight`.
+        let weights = Decode::decode(decoder)?;
+        let tag_weight: Vec<Vec<Vec<(u32, WeightVector)>>> = Decode::decode(decoder)?;
+        let tag_weight = tag_weight
+            .into_iter()
+            .map(|x| x.into_iter().map(|x| x.into_iter().collect()).collect())
+            .collect();
+        Ok(Self {
+            pma,
+            weights,
+            tag_weight,
+        })
+    }
+}
+
+impl Encode for CharScorerBoundaryTag {
+    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
+        let pma_data = self.pma.serialize_to_vec();
+        Encode::encode(&pma_data, encoder)?;
+        Encode::encode(&self.weights, encoder)?;
+        let tag_weight: Vec<Vec<Vec<(&u32, &WeightVector)>>> = self
+            .tag_weight
+            .iter()
+            .map(|x| x.iter().map(|x| x.iter().collect()).collect())
+            .collect();
+        Encode::encode(&tag_weight, encoder)?;
+        Ok(())
+    }
+}
+
+impl CharScorerBoundaryTag {
+    pub fn new(
+        ngram_model: NgramModel<String>,
+        dict_model: DictModel,
+        window_size: u8,
+        tag_ngram_model: Vec<TagNgramModel<String>>,
+    ) -> Result<Self> {
+        let mut merger = CharWeightMerger::default();
+        for d in ngram_model.0 {
+            let weight = PositionalWeightWithTag::with_boundary(-i16::from(window_size), d.weights);
+            merger.add(d.ngram, weight);
+        }
+        for d in dict_model.0 {
+            let word_len = d.word.chars().count();
+            let word_len = i16::try_from(word_len).map_err(|_| {
+                VaporettoError::invalid_model(
+                    "words must be shorter than or equal to 32767 characters",
+                )
+            })?;
+            let weight = PositionalWeightWithTag::with_boundary(-word_len, d.weights);
+            merger.add(d.word, weight);
+        }
+        let mut tag_weight =
+            vec![
+                vec![HashMap::with_hasher(SplitMix64Builder); usize::from(window_size) + 1];
+                tag_ngram_model.len()
+            ];
+        for (i, tag_model) in tag_ngram_model.into_iter().enumerate() {
+            for d in tag_model.0 {
+                for w in d.weights {
+                    let weight = PositionalWeightWithTag::with_tag(i, w.rel_position, w.weights);
+                    merger.add(&d.ngram, weight);
+                }
+            }
+        }
+        let mut ngrams = vec![];
+        let mut weights = vec![];
+        for (i, (ngram, weight)) in merger.merge().into_iter().enumerate() {
+            ngrams.push(ngram);
+            weights.push(weight.weight.map(|w| w.into()));
+            for ((token_id, rel_position), weight) in weight.tag_info {
+                tag_weight[token_id][usize::from(rel_position)]
+                    .insert(u32::try_from(i).unwrap(), weight.into());
+            }
+        }
+        #[cfg(not(feature = "charwise-pma"))]
+        let pma = DoubleArrayAhoCorasick::new(ngrams)
+            .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?;
+        #[cfg(feature = "charwise-pma")]
+        let pma = CharwiseDoubleArrayAhoCorasick::new(ngrams)
+            .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?;
+        Ok(Self {
+            pma,
+            weights,
+            tag_weight,
+        })
+    }
+
+    #[allow(clippy::cast_possible_truncation)]
+    #[allow(clippy::cast_possible_wrap)]
+    #[inline(always)]
+    pub fn add_scores<'a, 'b>(&self, sentence: &mut Sentence<'a, 'b>) {
+        sentence.char_pma_states.clear();
+        sentence.char_pma_states.resize(sentence.len(), u32::MAX);
+        #[cfg(not(feature = "charwise-pma"))]
+        let it = self
+            .pma
+            .find_overlapping_no_suffix_iter(sentence.text.as_bytes());
+        #[cfg(feature = "charwise-pma")]
+        let it = self.pma.find_overlapping_no_suffix_iter(&sentence.text);
+        for m in it {
+            debug_assert!(m.end() != 0 && sentence.text.is_char_boundary(m.end()));
+            let end = unsafe { sentence.str_to_char_pos(m.end()) };
+            debug_assert!(m.value() < self.weights.len());
+            if let Some(weight) = unsafe { self.weights.get_unchecked(m.value()).as_ref() } {
+                weight.add_score(
+                    (end + sentence.score_padding - 1) as isize,
+                    &mut sentence.boundary_scores,
+                );
+            }
+            debug_assert!(end as usize <= sentence.char_pma_states.len());
+            unsafe {
+                *sentence.char_pma_states.get_unchecked_mut(end as usize - 1) = m.value() as u32
+            };
+        }
+    }
+
+    /// # Safety
+    ///
+    /// `token_id` must be smaller than `scorer.tag_weight.len()`.
+    /// `pos` must be smaller than `sentence.char_pma_states.len()`.
+    #[inline(always)]
+    pub unsafe fn add_tag_scores(
+        &self,
+        token_id: u32,
+        pos: usize,
+        sentence: &Sentence,
+        scores: &mut [i32],
+    ) {
+        let tag_weight = self
+            .tag_weight
+            .get_unchecked(usize::try_from(token_id).unwrap());
+        for (state_id, tag_weights) in sentence
+            .char_pma_states
+            .get_unchecked(pos..)
+            .iter()
+            .zip(tag_weight)
+        {
+            if let Some(weight) = tag_weights.get(state_id) {
+                weight.add_scores(scores);
+            }
+        }
+    }
+}
diff --git a/vaporetto/src/dict_model.rs b/vaporetto/src/dict_model.rs
index 65cbb030..23b43a21 100644
--- a/vaporetto/src/dict_model.rs
+++ b/vaporetto/src/dict_model.rs
@@ -13,7 +13,7 @@ pub struct DictWeight {
 }
 
 /// Record of weights for each word.
-#[derive(Clone, Decode, Encode)]
+#[derive(Clone, Debug, Decode, Encode)]
 pub struct WordWeightRecord {
     pub(crate) word: String,
     pub(crate) weights: Vec<i32>,
@@ -26,12 +26,13 @@ impl WordWeightRecord {
     /// # Arguments
     ///
     /// * `word` - A word.
-    /// * `weights` - A weight of boundaries.
+    /// * `weights` - Weights of each character boundary.
     /// * `comment` - A comment that does not affect the behaviour.
     ///
-    /// # Returns
+    /// # Errors
    ///
-    /// A new record.
+    /// If `weights.len() != word.chars().count() + 1`,
+    /// an error variant will be returned.
     pub fn new(word: String, weights: Vec<i32>, comment: String) -> Result<Self> {
         if weights.len() != word.chars().count() + 1 {
             return Err(VaporettoError::invalid_argument(
@@ -62,17 +63,15 @@ impl WordWeightRecord {
     }
 }
 
-#[derive(Decode, Encode)]
-pub struct DictModel {
-    pub(crate) dict: Vec<WordWeightRecord>,
-}
+#[derive(Debug, Decode, Encode)]
+pub struct DictModel(pub(crate) Vec<WordWeightRecord>);
 
 impl DictModel {
     pub fn new(dict: Vec<WordWeightRecord>) -> Self {
-        Self { dict }
+        Self(dict)
     }
 
     pub fn dictionary(&self) -> &[WordWeightRecord] {
-        &self.dict
+        &self.0
     }
 }
diff --git a/vaporetto/src/errors.rs b/vaporetto/src/errors.rs
index 5206a6fc..e9d7dac0 100644
--- a/vaporetto/src/errors.rs
+++ b/vaporetto/src/errors.rs
@@ -7,19 +7,33 @@ use alloc::string::String;
 #[cfg(feature = "std")]
 use std::error::Error;
 
+/// A specialized Result type for Vaporetto.
 pub type Result<T, E = VaporettoError> = core::result::Result<T, E>;
 
+/// The error type for Vaporetto.
 #[derive(Debug)]
 pub enum VaporettoError {
+    /// The error variant for [`InvalidModelError`].
     InvalidModel(InvalidModelError),
-    InvalidSentence(InvalidSentenceError),
+
+    /// The error variant for [`InvalidArgumentError`].
     InvalidArgument(InvalidArgumentError),
+
+    /// The error variant for [`FromUtf8Error`](alloc::string::FromUtf8Error).
UTF8Error(alloc::string::FromUtf8Error), + + /// The error variant for [`TryFromIntError`](core::num::TryFromIntError). CastError(core::num::TryFromIntError), + + /// The error variant for [`DecodeError`](bincode::error::DecodeError). DecodeError(bincode::error::DecodeError), + + /// The error variant for [`EncodeError`](bincode::error::EncodeError). EncodeError(bincode::error::EncodeError), + /// The error variant for [`std::io::Error`]. #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] IOError(std::io::Error), } @@ -31,13 +45,6 @@ impl VaporettoError { Self::InvalidModel(InvalidModelError { msg: msg.into() }) } - pub(crate) fn invalid_sentence(msg: S) -> Self - where - S: Into, - { - Self::InvalidSentence(InvalidSentenceError { msg: msg.into() }) - } - pub(crate) fn invalid_argument(arg: &'static str, msg: S) -> Self where S: Into, @@ -53,7 +60,6 @@ impl fmt::Display for VaporettoError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::InvalidModel(e) => e.fmt(f), - Self::InvalidSentence(e) => e.fmt(f), Self::InvalidArgument(e) => e.fmt(f), Self::UTF8Error(e) => e.fmt(f), Self::CastError(e) => e.fmt(f), @@ -85,22 +91,6 @@ impl fmt::Display for InvalidModelError { #[cfg(feature = "std")] impl Error for InvalidModelError {} -/// Error used when the sentence is invalid. -#[derive(Debug)] -pub struct InvalidSentenceError { - /// Error message. - pub(crate) msg: String, -} - -impl fmt::Display for InvalidSentenceError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "InvalidSentenceError: {}", self.msg) - } -} - -#[cfg(feature = "std")] -impl Error for InvalidSentenceError {} - /// Error used when the argument is invalid. #[derive(Debug)] pub struct InvalidArgumentError { diff --git a/vaporetto/src/feature.rs b/vaporetto/src/feature.rs deleted file mode 100644 index e49ce55c..00000000 --- a/vaporetto/src/feature.rs +++ /dev/null @@ -1,769 +0,0 @@ -use std::hash::Hash; -use std::sync::Arc; - -use daachorse::DoubleArrayAhoCorasick; - -use crate::errors::{Result, VaporettoError}; -use crate::sentence::BoundaryType; -use crate::sentence::Sentence; - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] -pub struct StringNgramFeature<'a> { - pub(crate) rel_position: isize, - pub(crate) ngram: &'a str, -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct BytesNgramFeature<'a> { - pub(crate) rel_position: isize, - pub(crate) ngram: &'a [u8], -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub enum DictionaryWordPosition { - Right, - Left, - Inside, -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct DictionaryWordFeature { - pub(crate) position: DictionaryWordPosition, - pub(crate) length: usize, -} - -#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)] -pub enum BoundaryFeature<'a> { - CharacterNgram(StringNgramFeature<'a>), - CharacterTypeNgram(BytesNgramFeature<'a>), - DictionaryWord(DictionaryWordFeature), -} - -impl<'a> BoundaryFeature<'a> { - pub const fn char_ngram(rel_position: isize, ngram: &'a str) -> Self { - Self::CharacterNgram(StringNgramFeature { - rel_position, - ngram, - }) - } - - pub const fn type_ngram(rel_position: isize, ngram: &'a [u8]) -> Self { - Self::CharacterTypeNgram(BytesNgramFeature { - rel_position, - ngram, - }) - } - - pub const fn dict_word(position: DictionaryWordPosition, length: usize) -> Self { - Self::DictionaryWord(DictionaryWordFeature { position, length }) - } -} - -#[derive(Debug, PartialEq)] -pub struct BoundaryExample<'a> { - pub 
features: Vec>, - pub label: BoundaryType, -} - -pub struct BoundaryExampleGenerator { - char_ngram_size: u8, - type_ngram_size: u8, - char_window_size: u8, - type_window_size: u8, - dict_ac: Option, - dict_max_word_size: u8, -} - -impl BoundaryExampleGenerator { - pub fn new( - char_ngram_size: u8, - type_ngram_size: u8, - char_window_size: u8, - type_window_size: u8, - dict: Option, - dict_max_word_size: u8, - ) -> Result - where - I: IntoIterator, - P: AsRef<[u8]>, - { - let dict_ac = if let Some(dict) = dict { - Some( - DoubleArrayAhoCorasick::new(dict) - .map_err(|e| VaporettoError::invalid_argument("dict", format!("{:?}", e)))?, - ) - } else { - None - }; - Ok(Self { - char_ngram_size, - type_ngram_size, - char_window_size, - type_window_size, - dict_ac, - dict_max_word_size, - }) - } - - pub fn generate<'a>(&self, s: &'a Sentence) -> Result>> { - let mut result = vec![]; - for (i, &label) in s.boundaries().iter().enumerate() { - let mut features = vec![]; - for n in 0..usize::from(self.char_ngram_size) { - let begin = (i + 1).saturating_sub(usize::from(self.char_window_size)); - let end = (i + 1 + usize::from(self.char_window_size)) - .min(s.chars.len()) - .saturating_sub(n); - for pos in begin..end { - let rel_position = isize::try_from(pos)? - isize::try_from(i)? - 1; - let ngram = s.char_substring(pos, pos + n + 1); - features.push(BoundaryFeature::char_ngram(rel_position, ngram)); - } - } - for n in 0..usize::from(self.type_ngram_size) { - let begin = (i + 1).saturating_sub(usize::from(self.type_window_size)); - let end = (i + 1 + usize::from(self.type_window_size)) - .min(s.chars.len()) - .saturating_sub(n); - for pos in begin..end { - let rel_position = isize::try_from(pos)? - isize::try_from(i)? - 1; - let ngram = &s.char_types()[pos..pos + n + 1]; - features.push(BoundaryFeature::type_ngram(rel_position, ngram)); - } - } - result.push(BoundaryExample { features, label }) - } - if let Some(dict_ac) = self.dict_ac.as_ref() { - for m in dict_ac.find_overlapping_iter(&s.text) { - let m_start = s.str_to_char_pos[m.start()]; - let m_end = s.str_to_char_pos[m.end()]; - let length = (m_end - m_start).min(usize::from(self.dict_max_word_size)); - if m_start != 0 { - result[m_start - 1] - .features - .push(BoundaryFeature::dict_word( - DictionaryWordPosition::Right, - length, - )); - } - for example in &mut result[m_start..m_end - 1] { - example.features.push(BoundaryFeature::dict_word( - DictionaryWordPosition::Inside, - length, - )); - } - if m_end != s.chars().len() { - result[m_end - 1].features.push(BoundaryFeature::dict_word( - DictionaryWordPosition::Left, - length, - )); - } - } - } - Ok(result - .into_iter() - .filter(|example| example.label != BoundaryType::Unknown) - .collect()) - } -} - -#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum TagFeature<'a> { - LeftCharacterNgram(StringNgramFeature<'a>), - LeftCharacterNgramBos(StringNgramFeature<'a>), - RightCharacterNgram(StringNgramFeature<'a>), - RightCharacterNgramEos(StringNgramFeature<'a>), - Character(&'a str), -} - -impl<'a> TagFeature<'a> { - pub const fn left_char_ngram(rel_position: isize, ngram: &'a str) -> Self { - Self::LeftCharacterNgram(StringNgramFeature { - rel_position, - ngram, - }) - } - - pub const fn left_char_ngram_bos(rel_position: isize, ngram: &'a str) -> Self { - Self::LeftCharacterNgramBos(StringNgramFeature { - rel_position, - ngram, - }) - } - - pub const fn right_char_ngram(rel_position: isize, ngram: &'a str) -> Self { - Self::RightCharacterNgram(StringNgramFeature { - 
rel_position, - ngram, - }) - } - - pub const fn right_char_ngram_eos(rel_position: isize, ngram: &'a str) -> Self { - Self::RightCharacterNgramEos(StringNgramFeature { - rel_position, - ngram, - }) - } - - pub const fn chars(chars: &'a str) -> Self { - Self::Character(chars) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct TagExample<'a> { - pub features: Vec>, - pub tag: Arc, -} - -pub struct TagExampleGenerator { - char_ngram_size: u8, - char_window_size: u8, -} - -impl TagExampleGenerator { - pub const fn new(char_ngram_size: u8, char_window_size: u8) -> Self { - Self { - char_ngram_size, - char_window_size, - } - } - - pub fn generate<'a>(&self, sentence: &'a Sentence) -> Result>> { - let mut result = vec![]; - let mut features = vec![]; - for start in (sentence.chars.len() + 1).saturating_sub(usize::from(self.char_ngram_size)) - ..sentence.chars.len() + 1 - { - features.push(TagFeature::right_char_ngram_eos( - 1, - sentence.char_substring(start, sentence.chars.len()), - )); - } - let mut current_tag: Option> = sentence - .tags - .last() - .and_then(|x| x.as_ref()) - .map(Arc::clone); - let mut tag_right_pos = sentence.chars.len(); - for (i, (t, b)) in sentence - .tags - .iter() - .zip(sentence.boundaries()) - .enumerate() - .rev() - { - match b { - BoundaryType::WordBoundary => { - if let Some(tag) = current_tag.take() { - if i + 2 <= usize::from(self.char_window_size) { - let rel_position = -isize::try_from(i)? - 2; - for end in - 0..sentence.chars.len().min(usize::from(self.char_ngram_size)) - { - features.push(TagFeature::left_char_ngram_bos( - rel_position, - sentence.char_substring(0, end), - )); - } - } - for j in (i + 1).saturating_sub(usize::from(self.char_window_size))..i + 1 { - let rel_position = isize::try_from(j)? - isize::try_from(i)? - 1; - for end in j + 1 - ..sentence - .chars - .len() - .min(j + usize::from(self.char_ngram_size)) - + 1 - { - features.push(TagFeature::left_char_ngram( - rel_position, - sentence.char_substring(j, end), - )); - } - } - features.push(TagFeature::chars( - sentence.char_substring(i + 1, tag_right_pos), - )); - result.push(TagExample { features, tag }); - features = vec![]; - } - if let Some(tag) = t.as_ref() { - current_tag.replace(Arc::clone(tag)); - tag_right_pos = i + 1; - for j in (i + 2) - ..(i + 2 + usize::from(self.char_window_size)) - .min(sentence.chars.len() + 1) - { - let rel_position = isize::try_from(j - i)? 
- 1; - for start in j.saturating_sub(usize::from(self.char_ngram_size))..j { - features.push(TagFeature::right_char_ngram( - rel_position, - sentence.char_substring(start, j), - )); - } - } - if i + usize::from(self.char_window_size) >= sentence.chars.len() { - let rel_position = isize::try_from(sentence.chars.len() - i)?; - for start in (sentence.chars.len() + 1) - .saturating_sub(usize::from(self.char_ngram_size)) - ..sentence.chars.len() + 1 - { - features.push(TagFeature::right_char_ngram_eos( - rel_position, - sentence.char_substring(start, sentence.chars.len()), - )); - } - } - } - } - BoundaryType::NotWordBoundary => (), - BoundaryType::Unknown => { - if current_tag.is_some() { - return Err(VaporettoError::invalid_argument("sentence", "")); - } - } - } - } - if let Some(tag) = current_tag.take() { - for end in 0..sentence.chars.len().min(usize::from(self.char_ngram_size)) { - features.push(TagFeature::left_char_ngram_bos( - -1, - sentence.char_substring(0, end), - )); - } - features.push(TagFeature::chars(sentence.char_substring(0, tag_right_pos))); - result.push(TagExample { features, tag }); - } - Ok(result) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::sentence::CharacterType::*; - use BoundaryFeature::*; - use BoundaryType::*; - - #[test] - fn test_example_generator_generate_one() { - let dict = Some(["東京特許許可局", "火星猫", "猫"]); - let gen = BoundaryExampleGenerator::new(3, 2, 3, 2, dict, 2).unwrap(); - - let s = Sentence::from_raw("猫").unwrap(); - let examples = gen.generate(&s).unwrap(); - - assert!(examples.is_empty()); - } - - #[test] - fn test_example_generator_generate_all() { - let dict = Some(["東京特許許可局", "火星猫", "猫"]); - let gen = BoundaryExampleGenerator::new(3, 2, 3, 2, dict, 2).unwrap(); - - let s = Sentence::from_partial_annotation("A-r-i-a|は|火-星 猫|だ").unwrap(); - let examples = gen.generate(&s).unwrap(); - - assert_eq!(7, examples.len()); - - // pos 3 "A-r" - #[rustfmt::skip] - let expected = BoundaryExample { - features: vec![ - CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "A" }), - CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "r" }), - CharacterNgram(StringNgramFeature { rel_position: 1, ngram: "i" }), - CharacterNgram(StringNgramFeature { rel_position: 2, ngram: "a" }), - CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "Ar" }), - CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "ri" }), - CharacterNgram(StringNgramFeature { rel_position: 1, ngram: "ia" }), - CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "Ari" }), - CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "ria" }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: &[Roman as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: &[Roman as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: 1, ngram: &[Roman as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: &[Roman as u8, Roman as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: &[Roman as u8, Roman as u8] }), - ], - label: NotWordBoundary, - }; - assert_eq!(expected, examples[0]); - - // pos 3 "a|は" - #[rustfmt::skip] - let expected = BoundaryExample { - features: vec![ - CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "r" }), - CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "i" }), - CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "a" }), - CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "は" }), - 
CharacterNgram(StringNgramFeature { rel_position: 1, ngram: "火" }), - CharacterNgram(StringNgramFeature { rel_position: 2, ngram: "星" }), - CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "ri" }), - CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "ia" }), - CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "aは" }), - CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "は火" }), - CharacterNgram(StringNgramFeature { rel_position: 1, ngram: "火星" }), - CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "ria" }), - CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "iaは" }), - CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "aは火" }), - CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "は火星" }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -2, ngram: &[Roman as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: &[Roman as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: &[Hiragana as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: 1, ngram: &[Kanji as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -2, ngram: &[Roman as u8, Roman as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: &[Roman as u8, Hiragana as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: &[Hiragana as u8, Kanji as u8] }), - ], - label: WordBoundary, - }; - assert_eq!(expected, examples[3]); - - // pos 6 "星 猫" (skipped) - - // pos 7 "猫|だ" - #[rustfmt::skip] - let expected = BoundaryExample { - features: vec![ - CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "火" }), - CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "星" }), - CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "猫" }), - CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "だ" }), - CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "火星" }), - CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "星猫" }), - CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "猫だ" }), - CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "火星猫" }), - CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "星猫だ" }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -2, ngram: &[Kanji as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: &[Kanji as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: &[Hiragana as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -2, ngram: &[Kanji as u8, Kanji as u8] }), - CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: &[Kanji as u8, Hiragana as u8] }), - DictionaryWord(DictionaryWordFeature { position: DictionaryWordPosition::Left, length: 2 }), - DictionaryWord(DictionaryWordFeature { position: DictionaryWordPosition::Left, length: 1 }), - ], - label: WordBoundary, - }; - assert_eq!(expected, examples[6]); - } - - #[test] - fn test_example_generator_generate_without_unknown() { - let dict = Some(["東京特許許可局", "火星猫", "猫"]); - let gen = BoundaryExampleGenerator::new(3, 2, 3, 2, dict, 2).unwrap(); - - let s = Sentence::from_partial_annotation("A-r-i-a|は|火-星 猫|だ").unwrap(); - let examples = gen.generate(&s).unwrap(); - - assert_eq!(7, examples.len()); - } - - #[test] - fn test_tag_example_generate_33() { - let gen = TagExampleGenerator::new(3, 3); - - let s = - Sentence::from_partial_annotation("A-r-i-a/名詞|は/助詞|火-星 猫|だ/助動詞").unwrap(); - let mut examples = 
gen.generate(&s).unwrap(); - - // The order of examples is unimportant. - examples - .iter_mut() - .for_each(|example| example.features.sort_unstable()); - examples.sort_unstable(); - - let mut expected = vec![ - TagExample { - features: vec![ - TagFeature::right_char_ngram(1, "iaは"), - TagFeature::right_char_ngram(1, "aは"), - TagFeature::right_char_ngram(1, "は"), - TagFeature::right_char_ngram(2, "aは火"), - TagFeature::right_char_ngram(2, "は火"), - TagFeature::right_char_ngram(2, "火"), - TagFeature::right_char_ngram(3, "は火星"), - TagFeature::right_char_ngram(3, "火星"), - TagFeature::right_char_ngram(3, "星"), - TagFeature::left_char_ngram_bos(-1, ""), - TagFeature::left_char_ngram_bos(-1, "A"), - TagFeature::left_char_ngram_bos(-1, "Ar"), - TagFeature::chars("Aria"), - ], - tag: Arc::new("名詞".to_string()), - }, - TagExample { - features: vec![ - TagFeature::right_char_ngram(1, "aは火"), - TagFeature::right_char_ngram(1, "は火"), - TagFeature::right_char_ngram(1, "火"), - TagFeature::right_char_ngram(2, "は火星"), - TagFeature::right_char_ngram(2, "火星"), - TagFeature::right_char_ngram(2, "星"), - TagFeature::right_char_ngram(3, "火星猫"), - TagFeature::right_char_ngram(3, "星猫"), - TagFeature::right_char_ngram(3, "猫"), - TagFeature::left_char_ngram(-3, "r"), - TagFeature::left_char_ngram(-3, "ri"), - TagFeature::left_char_ngram(-3, "ria"), - TagFeature::left_char_ngram(-2, "i"), - TagFeature::left_char_ngram(-2, "ia"), - TagFeature::left_char_ngram(-2, "iaは"), - TagFeature::left_char_ngram(-1, "a"), - TagFeature::left_char_ngram(-1, "aは"), - TagFeature::left_char_ngram(-1, "aは火"), - TagFeature::chars("は"), - ], - tag: Arc::new("助詞".to_string()), - }, - TagExample { - features: vec![ - TagFeature::right_char_ngram_eos(1, "猫だ"), - TagFeature::right_char_ngram_eos(1, "だ"), - TagFeature::right_char_ngram_eos(1, ""), - TagFeature::left_char_ngram(-3, "火"), - TagFeature::left_char_ngram(-3, "火星"), - TagFeature::left_char_ngram(-3, "火星猫"), - TagFeature::left_char_ngram(-2, "星"), - TagFeature::left_char_ngram(-2, "星猫"), - TagFeature::left_char_ngram(-2, "星猫だ"), - TagFeature::left_char_ngram(-1, "猫"), - TagFeature::left_char_ngram(-1, "猫だ"), - TagFeature::chars("だ"), - ], - tag: Arc::new("助動詞".to_string()), - }, - ]; - - expected - .iter_mut() - .for_each(|example| example.features.sort_unstable()); - expected.sort_unstable(); - - assert_eq!(expected, examples); - } - - #[test] - fn test_tag_example_generate_32() { - let gen = TagExampleGenerator::new(3, 2); - - let s = - Sentence::from_partial_annotation("A-r-i-a/名詞|は/助詞|火-星 猫|だ/助動詞").unwrap(); - let mut examples = gen.generate(&s).unwrap(); - - // The order of examples is unimportant. 
- examples - .iter_mut() - .for_each(|example| example.features.sort_unstable()); - examples.sort_unstable(); - - let mut expected = vec![ - TagExample { - features: vec![ - TagFeature::right_char_ngram(1, "iaは"), - TagFeature::right_char_ngram(1, "aは"), - TagFeature::right_char_ngram(1, "は"), - TagFeature::right_char_ngram(2, "aは火"), - TagFeature::right_char_ngram(2, "は火"), - TagFeature::right_char_ngram(2, "火"), - TagFeature::left_char_ngram_bos(-1, ""), - TagFeature::left_char_ngram_bos(-1, "A"), - TagFeature::left_char_ngram_bos(-1, "Ar"), - TagFeature::chars("Aria"), - ], - tag: Arc::new("名詞".to_string()), - }, - TagExample { - features: vec![ - TagFeature::right_char_ngram(1, "aは火"), - TagFeature::right_char_ngram(1, "は火"), - TagFeature::right_char_ngram(1, "火"), - TagFeature::right_char_ngram(2, "は火星"), - TagFeature::right_char_ngram(2, "火星"), - TagFeature::right_char_ngram(2, "星"), - TagFeature::left_char_ngram(-2, "i"), - TagFeature::left_char_ngram(-2, "ia"), - TagFeature::left_char_ngram(-2, "iaは"), - TagFeature::left_char_ngram(-1, "a"), - TagFeature::left_char_ngram(-1, "aは"), - TagFeature::left_char_ngram(-1, "aは火"), - TagFeature::chars("は"), - ], - tag: Arc::new("助詞".to_string()), - }, - TagExample { - features: vec![ - TagFeature::right_char_ngram_eos(1, "猫だ"), - TagFeature::right_char_ngram_eos(1, "だ"), - TagFeature::right_char_ngram_eos(1, ""), - TagFeature::left_char_ngram(-2, "星"), - TagFeature::left_char_ngram(-2, "星猫"), - TagFeature::left_char_ngram(-2, "星猫だ"), - TagFeature::left_char_ngram(-1, "猫"), - TagFeature::left_char_ngram(-1, "猫だ"), - TagFeature::chars("だ"), - ], - tag: Arc::new("助動詞".to_string()), - }, - ]; - - expected - .iter_mut() - .for_each(|example| example.features.sort_unstable()); - expected.sort_unstable(); - - assert_eq!(expected, examples); - } - - #[test] - fn test_tag_example_generate_23() { - let gen = TagExampleGenerator::new(2, 3); - - let s = - Sentence::from_partial_annotation("A-r-i-a/名詞|は/助詞|火-星 猫|だ/助動詞").unwrap(); - let mut examples = gen.generate(&s).unwrap(); - - // The order of examples is unimportant. 
- examples - .iter_mut() - .for_each(|example| example.features.sort_unstable()); - examples.sort_unstable(); - - let mut expected = vec![ - TagExample { - features: vec![ - TagFeature::right_char_ngram(1, "aは"), - TagFeature::right_char_ngram(1, "は"), - TagFeature::right_char_ngram(2, "は火"), - TagFeature::right_char_ngram(2, "火"), - TagFeature::right_char_ngram(3, "火星"), - TagFeature::right_char_ngram(3, "星"), - TagFeature::left_char_ngram_bos(-1, ""), - TagFeature::left_char_ngram_bos(-1, "A"), - TagFeature::chars("Aria"), - ], - tag: Arc::new("名詞".to_string()), - }, - TagExample { - features: vec![ - TagFeature::right_char_ngram(1, "は火"), - TagFeature::right_char_ngram(1, "火"), - TagFeature::right_char_ngram(2, "火星"), - TagFeature::right_char_ngram(2, "星"), - TagFeature::right_char_ngram(3, "星猫"), - TagFeature::right_char_ngram(3, "猫"), - TagFeature::left_char_ngram(-3, "r"), - TagFeature::left_char_ngram(-3, "ri"), - TagFeature::left_char_ngram(-2, "i"), - TagFeature::left_char_ngram(-2, "ia"), - TagFeature::left_char_ngram(-1, "a"), - TagFeature::left_char_ngram(-1, "aは"), - TagFeature::chars("は"), - ], - tag: Arc::new("助詞".to_string()), - }, - TagExample { - features: vec![ - TagFeature::right_char_ngram_eos(1, "だ"), - TagFeature::right_char_ngram_eos(1, ""), - TagFeature::left_char_ngram(-3, "火"), - TagFeature::left_char_ngram(-3, "火星"), - TagFeature::left_char_ngram(-2, "星"), - TagFeature::left_char_ngram(-2, "星猫"), - TagFeature::left_char_ngram(-1, "猫"), - TagFeature::left_char_ngram(-1, "猫だ"), - TagFeature::chars("だ"), - ], - tag: Arc::new("助動詞".to_string()), - }, - ]; - - expected - .iter_mut() - .for_each(|example| example.features.sort_unstable()); - expected.sort_unstable(); - - assert_eq!(expected, examples); - } - - #[test] - fn test_tag_example_generate_check_sentence_boundary() { - let gen = TagExampleGenerator::new(3, 3); - - let s = Sentence::from_tokenized("僕/代名詞 は/助詞 人間/名詞").unwrap(); - let mut examples = gen.generate(&s).unwrap(); - - // The order of examples is unimportant. 
- examples - .iter_mut() - .for_each(|example| example.features.sort_unstable()); - examples.sort_unstable(); - - let mut expected = vec![ - TagExample { - features: vec![ - TagFeature::right_char_ngram(1, "僕は"), - TagFeature::right_char_ngram(1, "は"), - TagFeature::right_char_ngram(2, "僕は人"), - TagFeature::right_char_ngram(2, "は人"), - TagFeature::right_char_ngram(2, "人"), - TagFeature::right_char_ngram(3, "は人間"), - TagFeature::right_char_ngram(3, "人間"), - TagFeature::right_char_ngram(3, "間"), - TagFeature::left_char_ngram_bos(-1, ""), - TagFeature::left_char_ngram_bos(-1, "僕"), - TagFeature::left_char_ngram_bos(-1, "僕は"), - TagFeature::chars("僕"), - ], - tag: Arc::new("代名詞".to_string()), - }, - TagExample { - features: vec![ - TagFeature::right_char_ngram(1, "僕は人"), - TagFeature::right_char_ngram(1, "は人"), - TagFeature::right_char_ngram(1, "人"), - TagFeature::right_char_ngram(2, "は人間"), - TagFeature::right_char_ngram(2, "人間"), - TagFeature::right_char_ngram(2, "間"), - TagFeature::right_char_ngram_eos(3, "人間"), - TagFeature::right_char_ngram_eos(3, "間"), - TagFeature::right_char_ngram_eos(3, ""), - TagFeature::left_char_ngram_bos(-2, "僕は"), - TagFeature::left_char_ngram_bos(-2, "僕"), - TagFeature::left_char_ngram_bos(-2, ""), - TagFeature::left_char_ngram(-1, "僕は人"), - TagFeature::left_char_ngram(-1, "僕は"), - TagFeature::left_char_ngram(-1, "僕"), - TagFeature::chars("は"), - ], - tag: Arc::new("助詞".to_string()), - }, - TagExample { - features: vec![ - TagFeature::right_char_ngram_eos(1, "人間"), - TagFeature::right_char_ngram_eos(1, "間"), - TagFeature::right_char_ngram_eos(1, ""), - TagFeature::left_char_ngram_bos(-3, "僕は"), - TagFeature::left_char_ngram_bos(-3, "僕"), - TagFeature::left_char_ngram_bos(-3, ""), - TagFeature::left_char_ngram(-2, "僕は人"), - TagFeature::left_char_ngram(-2, "僕は"), - TagFeature::left_char_ngram(-2, "僕"), - TagFeature::left_char_ngram(-1, "は人間"), - TagFeature::left_char_ngram(-1, "は人"), - TagFeature::left_char_ngram(-1, "は"), - TagFeature::chars("人間"), - ], - tag: Arc::new("名詞".to_string()), - }, - ]; - - expected - .iter_mut() - .for_each(|example| example.features.sort_unstable()); - expected.sort_unstable(); - - assert_eq!(expected, examples); - } -} diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index 61a3af76..08c142bd 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -5,7 +5,6 @@ use crate::dict_model::{DictModel, DictWeight, WordWeightRecord}; use crate::errors::{Result, VaporettoError}; use crate::model::Model; use crate::ngram_model::{NgramData, NgramModel}; -use crate::tag_model::TagModel; use crate::utils; struct KyteaConfig { @@ -495,13 +494,13 @@ impl TryFrom for Model { } Ok(Self::new( - NgramModel { data: char_ngrams }, - NgramModel { data: type_ngrams }, + NgramModel(char_ngrams), + NgramModel(type_ngrams), DictModel::new(dict), bias, config.char_w, config.type_w, - TagModel::default(), + vec![], )) } } diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs index 6357fdf1..07bd4513 100644 --- a/vaporetto/src/lib.rs +++ b/vaporetto/src/lib.rs @@ -2,32 +2,58 @@ //! //! Vaporetto is a fast and lightweight pointwise prediction based tokenizer. //! -//! ## Examples //! -//! ```no_run -//! use std::fs::File; -//! use std::io::Read; -//! -//! use vaporetto::{Model, Predictor, Sentence}; -//! -//! let mut f = File::open("model.bin").unwrap(); -//! let mut model_data = vec![]; -//! f.read_to_end(&mut model_data).unwrap(); -//! let (model, _) = Model::read_slice(&model_data).unwrap(); -//! 
let predictor = Predictor::new(model, false).unwrap(); -//! -//! let s = Sentence::from_raw("火星猫の生態").unwrap(); -//! let s = predictor.predict(s); +#![cfg_attr( + all(feature = "std", feature = "tag-prediction"), + doc = " +## Examples + +``` +use std::fs::File; + +use vaporetto::{Model, Predictor, Sentence}; + +let f = File::open(\"../resources/model.bin\").unwrap(); +let model = Model::read(f).unwrap(); +let predictor = Predictor::new(model, true).unwrap(); + +let mut buf = String::new(); + +let mut s = Sentence::default(); + +s.update_raw(\"まぁ社長は火星猫だ\").unwrap(); +predictor.predict(&mut s); +s.fill_tags(); +s.write_tokenized_text(&mut buf); +assert_eq!( + \"まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ 火星/名詞/カセー 猫/名詞/ネコ だ/助動詞/ダ\", + buf, +); + +s.update_raw(\"まぁ良いだろう\").unwrap(); +predictor.predict(&mut s); +s.fill_tags(); +s.write_tokenized_text(&mut buf); +assert_eq!( + \"まぁ/副詞/マー 良い/形容詞/ヨイ だろう/助動詞/ダロー\", + buf, +); +``` +" +)] //! -//! println!("{:?}", s.to_tokenized_vec().unwrap()); -//! ``` +//! Tag prediction requires **crate feature** `tag-prediction`. //! //! Training requires **crate feature** `train`. For more details, see [`Trainer`]. +#![deny(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(feature = "portable-simd", feature(portable_simd))] +#[cfg(not(feature = "alloc"))] +compile_error!("`alloc` feature is currently required to build this crate"); + #[macro_use] extern crate alloc; @@ -37,14 +63,11 @@ mod model; mod ngram_model; mod predictor; mod sentence; -mod tag_model; mod type_scorer; mod utils; pub mod errors; -#[cfg(feature = "train")] -mod feature; #[cfg(feature = "train")] mod tag_trainer; #[cfg(feature = "train")] @@ -56,7 +79,7 @@ mod kytea_model; pub use dict_model::WordWeightRecord; pub use model::Model; pub use predictor::Predictor; -pub use sentence::{BoundaryType, CharacterType, Sentence, Token}; +pub use sentence::{CharacterBoundary, CharacterType, Sentence, Token, TokenIterator}; #[cfg(feature = "train")] pub use trainer::{SolverType, Trainer}; diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs index 25ecc7d1..e3bd9664 100644 --- a/vaporetto/src/model.rs +++ b/vaporetto/src/model.rs @@ -8,19 +8,48 @@ use bincode::{Decode, Encode}; use crate::dict_model::{DictModel, WordWeightRecord}; use crate::errors::{Result, VaporettoError}; -use crate::ngram_model::NgramModel; -use crate::tag_model::TagModel; +use crate::ngram_model::{NgramModel, TagNgramModel}; use crate::utils::VecWriter; /// Magic number. -const MODEL_MAGIC: &[u8] = b"VaporettoTokenizer 0.4.0\n"; +const MODEL_MAGIC: &[u8] = b"VaporettoTokenizer 0.5.0\n"; -/// Model data. -pub struct Model { - pub(crate) data: ModelData, +// For each token, a model is trained for every tag independently, but the scores of all tags are +// calculated in parallel during prediction. +// Thus, the score array is a concatenation of all classes of all tags. +// +// For example, the following token has 3 POS tags and 3 pronunciation tags, so the score array +// contains 6 items. The predictor picks the tag with the largest score. +// +// token: "君" +// tags: [["名詞", "代名詞", "接尾辞"], ["クン", "キミ", "ギミ"]] +// scores: [ 176, 3647, 39, 518, 9346, 126 ] +// +// results: ["代名詞", "キミ"] +// +// If there is only one tag candidate, the model is not trained. +// In the following example, the predictor determines the first tag without prediction, so the +// score array only contains scores for the second tag. 
+//
+// token: "犬"
+// tags:   [["名詞"], ["イヌ", "ケン"]]
+// scores: [            475,  1563  ]
+//
+// results: ["名詞", "ケン"]
+#[derive(Debug, Decode, Encode)]
+pub struct TagModel {
+    pub(crate) token: String,
+    pub(crate) tags: Vec<Vec<String>>,
+    pub(crate) char_ngram_model: TagNgramModel<String>,
+    pub(crate) type_ngram_model: TagNgramModel<Vec<u8>>,
+    pub(crate) bias: Vec<i32>,
+}
 
-#[derive(Decode, Encode)]
+/// Model data.
+#[derive(Debug)]
+pub struct Model(pub(crate) ModelData);
+
+#[derive(Debug, Decode, Encode)]
 pub struct ModelData {
     pub(crate) char_ngram_model: NgramModel<String>,
     pub(crate) type_ngram_model: NgramModel<Vec<u8>>,
@@ -28,7 +57,8 @@ pub struct ModelData {
     pub(crate) bias: i32,
     pub(crate) char_window_size: u8,
     pub(crate) type_window_size: u8,
-    pub(crate) tag_model: TagModel,
+    // Instead of using Map, we use Vec to increase compression ratio and performance.
+    pub(crate) tag_models: Vec<TagModel>,
 }
 
 impl Model {
@@ -40,19 +70,17 @@ impl Model {
         bias: i32,
         char_window_size: u8,
         type_window_size: u8,
-        tag_model: TagModel,
+        tag_models: Vec<TagModel>,
     ) -> Self {
-        Self {
-            data: ModelData {
-                char_ngram_model,
-                type_ngram_model,
-                dict_model,
-                bias,
-                char_window_size,
-                type_window_size,
-                tag_model,
-            },
-        }
+        Self(ModelData {
+            char_ngram_model,
+            type_ngram_model,
+            dict_model,
+            bias,
+            char_window_size,
+            type_window_size,
+            tag_models,
+        })
     }
 
     /// Exports the model data into a [`Vec`].
@@ -63,16 +91,12 @@
     pub fn to_vec(&self) -> Result<Vec<u8>> {
         let mut wtr = VecWriter(MODEL_MAGIC.to_vec());
         let config = bincode::config::standard();
-        bincode::encode_into_writer(&self.data, &mut wtr, config)?;
+        bincode::encode_into_writer(&self.0, &mut wtr, config)?;
         Ok(wtr.0)
     }
 
     /// Exports the model data.
     ///
-    /// # Arguments
-    ///
-    /// * `wtr` - Byte-oriented sink object.
-    ///
     /// # Errors
     ///
     /// When bincode generates an error, it will be returned as is.
@@ -83,19 +107,11 @@
     {
         wtr.write_all(MODEL_MAGIC)?;
         let config = bincode::config::standard();
-        bincode::encode_into_std_write(&self.data, &mut wtr, config)?;
+        bincode::encode_into_std_write(&self.0, &mut wtr, config)?;
         Ok(())
     }
 
-    /// Creates a model from a slice.
-    ///
-    /// # Arguments
-    ///
-    /// * `slice` - A data source.
-    ///
-    /// # Returns
-    ///
-    /// A tuple of the model data read from `slice` and the remaining slice.
+    /// Creates a model from a slice and returns a tuple of the model and the remaining slice.
     ///
     /// # Errors
    ///
         }
         let config = bincode::config::standard();
         let (data, size) = bincode::decode_from_slice(&slice[MODEL_MAGIC.len()..], config)?;
-        Ok((Self { data }, &slice[MODEL_MAGIC.len() + size..]))
+        Ok((Self(data), &slice[MODEL_MAGIC.len() + size..]))
     }
 
     /// Creates a model from a reader.
     ///
-    /// # Arguments
-    ///
-    /// * `rdr` - A data source.
-    ///
-    /// # Returns
-    ///
-    /// A model data read from `rdr`.
-    ///
     /// # Errors
     ///
     /// When bincode generates an error, it will be returned as is.
@@ -133,16 +141,16 @@
             return Err(VaporettoError::invalid_model("model version mismatch"));
         }
         let config = bincode::config::standard();
-        Ok(Self {
-            data: bincode::decode_from_std_read(&mut rdr, config)?,
-        })
+        Ok(Self(bincode::decode_from_std_read(&mut rdr, config)?))
     }
 
+    /// Returns the slice of dictionary words.
     pub fn dictionary(&self) -> &[WordWeightRecord] {
-        self.data.dict_model.dictionary()
+        self.0.dict_model.dictionary()
     }
 
+    /// Replaces the dictionary with the given data.
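+    ///
+    /// The sketch below is illustrative only (it assumes `model_data` holds a serialized
+    /// model); it extends the current dictionary with one extra record:
+    ///
+    /// ```ignore
+    /// use vaporetto::{Model, WordWeightRecord};
+    ///
+    /// let (mut model, _) = Model::read_slice(&model_data).unwrap();
+    /// let mut dict = model.dictionary().to_vec();
+    /// let record = WordWeightRecord::new(
+    ///     "火星猫".to_string(),
+    ///     vec![10, -10, -10, 10], // 3 characters -> 4 boundary weights
+    ///     String::new(),
+    /// ).unwrap();
+    /// dict.push(record);
+    /// model.replace_dictionary(dict);
+    /// ```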
     pub fn replace_dictionary(&mut self, dict: Vec<WordWeightRecord>) {
-        self.data.dict_model = DictModel::new(dict);
+        self.0.dict_model = DictModel::new(dict);
     }
 }
diff --git a/vaporetto/src/ngram_model.rs b/vaporetto/src/ngram_model.rs
index dab0efff..6a00e34d 100644
--- a/vaporetto/src/ngram_model.rs
+++ b/vaporetto/src/ngram_model.rs
@@ -2,13 +2,26 @@ use alloc::vec::Vec;
 
 use bincode::{Decode, Encode};
 
-#[derive(Clone, Decode, Encode)]
+#[derive(Clone, Debug, Decode, Encode)]
 pub struct NgramData<T> {
     pub(crate) ngram: T,
     pub(crate) weights: Vec<i32>,
 }
 
-#[derive(Default, Decode, Encode)]
-pub struct NgramModel<T> {
-    pub(crate) data: Vec<NgramData<T>>,
+#[derive(Default, Debug, Decode, Encode)]
+pub struct NgramModel<T>(pub Vec<NgramData<T>>);
+
+#[derive(Clone, Debug, Decode, Encode)]
+pub struct TagWeight {
+    pub(crate) rel_position: u8,
+    pub(crate) weights: Vec<i32>,
+}
+
+#[derive(Clone, Debug, Decode, Encode)]
+pub struct TagNgramData<T> {
+    pub(crate) ngram: T,
+    pub(crate) weights: Vec<TagWeight>,
 }
+
+#[derive(Default, Debug, Decode, Encode)]
+pub struct TagNgramModel<T>(pub Vec<TagNgramData<T>>);
diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs
index c97910de..1cf93cf3 100644
--- a/vaporetto/src/predictor.rs
+++ b/vaporetto/src/predictor.rs
@@ -1,334 +1,624 @@
-use core::mem;
+use core::ops::AddAssign;
+
+#[cfg(all(feature = "fix-weight-length", feature = "portable-simd"))]
+use core::simd::Simd;
 
 use alloc::vec::Vec;
 
 #[cfg(feature = "tag-prediction")]
-use core::cmp::Ordering;
-
+use alloc::borrow::Cow;
 #[cfg(feature = "tag-prediction")]
 use alloc::string::String;
-#[cfg(feature = "tag-prediction")]
-use alloc::sync::Arc;
 
-use bincode::{BorrowDecode, Encode};
+use bincode::{
+    de::{BorrowDecoder, Decoder},
+    enc::Encoder,
+    error::{DecodeError, EncodeError},
+    BorrowDecode, Decode, Encode,
+};
+
+#[cfg(feature = "tag-prediction")]
+use hashbrown::HashMap;
 
-use crate::char_scorer::{self, CharScorer};
+use crate::char_scorer::CharScorer;
 use crate::errors::Result;
 use crate::model::Model;
-use crate::sentence::{BoundaryType, Sentence};
+use crate::sentence::{CharacterBoundary, Sentence};
 use crate::type_scorer::TypeScorer;
 
 #[cfg(feature = "tag-prediction")]
-use crate::char_scorer::CharScorerWithTags;
+use crate::utils::SerializableHashMap;
 
-#[derive(BorrowDecode, Encode)]
-enum CharScorerWrapper {
-    Boundary(CharScorer),
+pub const WEIGHT_FIXED_LEN: usize = 8;
 
-    #[cfg(feature = "tag-prediction")]
-    BoundaryAndTags(CharScorerWithTags),
-}
+#[cfg(all(feature = "fix-weight-length", not(feature = "portable-simd")))]
+pub type I32Simd = [i32; WEIGHT_FIXED_LEN];
+#[cfg(all(feature = "fix-weight-length", feature = "portable-simd"))]
+pub type I32Simd = Simd<i32, WEIGHT_FIXED_LEN>;
+
+#[derive(Clone, Debug)]
+pub enum WeightVector {
+    Variable(Vec<i32>),
 
-/// Predictor.
-pub struct Predictor {
-    data: PredictorData,
+    #[cfg(feature = "fix-weight-length")]
+    Fixed(I32Simd),
 }
 
-/// WARNING: The decode feature is inherently unsafe. Do not publish this feature outside this
-/// crate.
-#[derive(BorrowDecode, Encode)] -struct PredictorData { - bias: i32, +impl Decode for WeightVector { + fn decode(decoder: &mut D) -> Result { + let weight: Vec = Decode::decode(decoder)?; + Ok(Self::from(weight)) + } +} - char_scorer: CharScorerWrapper, - type_scorer: TypeScorer, +impl Encode for WeightVector { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + match self { + WeightVector::Variable(w) => { + Encode::encode(&w, encoder)?; + } - padding: u8, + #[cfg(feature = "fix-weight-length")] + WeightVector::Fixed(w) => { + #[cfg(feature = "portable-simd")] + let w = w.as_array(); - #[cfg(feature = "tag-prediction")] - tag_names: Vec>, - #[cfg(feature = "tag-prediction")] - tag_bias: Vec, + Encode::encode(&crate::utils::trim_end_zeros(w).to_vec(), encoder)?; + } + } + Ok(()) + } } -impl Predictor { - /// Creates a new predictor. - /// - /// # Arguments - /// - /// * `model` - A model data. - /// * `predict_tags` - If you want to predict tags, set to true. - /// - /// # Returns - /// - /// A new predictor. - pub fn new( - model: Model, - #[cfg(feature = "tag-prediction")] predict_tags: bool, - #[cfg(not(feature = "tag-prediction"))] _predict_tags: bool, - ) -> Result { - #[cfg(feature = "tag-prediction")] - let mut tag_names = vec![]; - #[cfg(feature = "tag-prediction")] - let mut tag_bias = vec![]; +#[cfg(feature = "tag-prediction")] +impl WeightVector { + pub fn add_scores(&self, ys: &mut [i32]) { + match self { + WeightVector::Variable(w) => { + for (y, x) in ys.iter_mut().zip(w) { + *y += *x; + } + } - #[cfg(feature = "tag-prediction")] - let char_scorer = if predict_tags { - for cls in model.data.tag_model.class_info { - tag_names.push(Arc::new(cls.name)); - tag_bias.push(cls.bias); + #[cfg(feature = "fix-weight-length")] + WeightVector::Fixed(w) => { + #[cfg(not(feature = "portable-simd"))] + for (y, x) in ys[..WEIGHT_FIXED_LEN].iter_mut().zip(w) { + *y += *x + } + + #[cfg(feature = "portable-simd")] + { + let ys = &mut ys[..WEIGHT_FIXED_LEN]; + let mut y = I32Simd::from_slice(ys); + y += w; + ys.copy_from_slice(y.as_array()); + } } - CharScorerWrapper::BoundaryAndTags(CharScorerWithTags::new( - model.data.char_ngram_model, - model.data.char_window_size, - model.data.dict_model, - tag_names.len(), - model.data.tag_model.left_char_model, - model.data.tag_model.right_char_model, - model.data.tag_model.self_char_model, - )?) - } else { - CharScorerWrapper::Boundary(CharScorer::new( - model.data.char_ngram_model, - model.data.char_window_size, - model.data.dict_model, - )?) 
- }; + } + } - #[cfg(not(feature = "tag-prediction"))] - let char_scorer = CharScorerWrapper::Boundary(CharScorer::new( - model.data.char_ngram_model, - model.data.char_window_size, - model.data.dict_model, - )?); + pub fn len(&self) -> usize { + match self { + WeightVector::Variable(w) => w.len(), - let type_scorer = - TypeScorer::new(model.data.type_ngram_model, model.data.type_window_size)?; + #[cfg(feature = "fix-weight-length")] + WeightVector::Fixed(_) => WEIGHT_FIXED_LEN, + } + } +} - Ok(Self { - data: PredictorData { - bias: model.data.bias, +impl From> for WeightVector { + fn from(src: Vec) -> Self { + match src.len() { + #[cfg(feature = "fix-weight-length")] + 0..=WEIGHT_FIXED_LEN => { + let mut weight = [0; WEIGHT_FIXED_LEN]; + weight[..src.len()].copy_from_slice(&src); - char_scorer, - type_scorer, + #[cfg(feature = "portable-simd")] + let weight = I32Simd::from(weight); - padding: model.data.char_window_size.max(model.data.type_window_size), + Self::Fixed(weight) + } - #[cfg(feature = "tag-prediction")] - tag_names, - #[cfg(feature = "tag-prediction")] - tag_bias, - }, - }) + _ => Self::Variable(src), + } } +} - /// Serializes the predictor into a Vec. - pub fn serialize_to_vec(&self) -> Result> { - let config = bincode::config::standard(); - let result = bincode::encode_to_vec(&self.data, config)?; - Ok(result) +#[derive(Clone, Debug, Default, Eq, PartialEq, Decode, Encode)] +pub struct PositionalWeight { + offset: i16, + weight: W, +} + +impl PositionalWeight> { + pub fn new(offset: i16, weight: Vec) -> Self { + Self { offset, weight } } +} - /// Deserializes a predictor from a given slice. - /// - /// # Arguments - /// - /// * `data` - A source slice. - /// - /// # Returns - /// - /// A tuple of a predictor and a slice not used for the deserialization. - /// - /// # Safety - /// - /// The given data must be a correct predictor exported by [`Predictor::serialize_to_vec()`] - /// function. - pub unsafe fn deserialize_from_slice_unchecked(data: &[u8]) -> Result<(Self, &[u8])> { - let config = bincode::config::standard(); - // Deserialization is unsafe because the automaton will not be verified. - let (predictor_data, size) = bincode::decode_from_slice(data, config)?; - Ok(( - Self { - data: predictor_data, - }, - &data[size..], - )) +impl AddAssign<&Self> for PositionalWeight> { + fn add_assign(&mut self, other: &Self) { + let new_offset = self.offset.min(other.offset); + let shift = usize::try_from(self.offset - new_offset).unwrap(); + let new_size = (shift + self.weight.len()) + .max(usize::try_from(other.offset - new_offset).unwrap() + other.weight.len()); + self.weight.resize(new_size, 0); + self.weight.rotate_right(shift); + for (y, x) in self.weight[usize::try_from(other.offset - new_offset).unwrap()..] 
+ .iter_mut() + .zip(&other.weight) + { + *y += *x; + } + self.offset = new_offset; + } +} + +impl From>> for PositionalWeight { + fn from(src: PositionalWeight>) -> Self { + Self { + offset: src.offset, + weight: src.weight.into(), + } } +} - fn predict_impl(&self, mut sentence: Sentence) -> Sentence { - let ys_size = - sentence.boundaries.len() + usize::from(self.data.padding) + char_scorer::SIMD_SIZE - 1; - let mut ys = mem::take(&mut sentence.boundary_scores); - ys.clear(); - ys.resize(ys_size, self.data.bias); - match &self.data.char_scorer { - CharScorerWrapper::Boundary(char_scorer) => { - char_scorer.add_scores(&sentence, self.data.padding, &mut ys); +impl PositionalWeight { + #[inline(always)] + pub fn add_score(&self, end: isize, ys: &mut [i32]) { + let pos = end + isize::from(self.offset); + match &self.weight { + WeightVector::Variable(w) => { + if pos >= 0 { + for (y, x) in ys[pos as usize..].iter_mut().zip(w) { + *y += *x; + } + } else if let Some(xs) = w.get((-pos) as usize..) { + for (y, x) in ys.iter_mut().zip(xs) { + *y += *x; + } + } } - #[cfg(feature = "tag-prediction")] - CharScorerWrapper::BoundaryAndTags(char_scorer) => { - let mut tag_ys = mem::take(&mut sentence.tag_scores); - tag_ys.init(sentence.chars.len(), self.data.tag_names.len()); - char_scorer.add_scores(&sentence, self.data.padding, &mut ys, &mut tag_ys); - sentence.tag_scores = tag_ys; + #[cfg(feature = "fix-weight-length")] + WeightVector::Fixed(w) => { + #[cfg(not(feature = "portable-simd"))] + for (y, x) in ys[pos as usize..pos as usize + WEIGHT_FIXED_LEN] + .iter_mut() + .zip(w) + { + *y += *x + } + + #[cfg(feature = "portable-simd")] + { + let ys = &mut ys[pos as usize..pos as usize + WEIGHT_FIXED_LEN]; + let mut y = I32Simd::from_slice(ys); + y += w; + ys.copy_from_slice(y.as_array()); + } } } - self.data - .type_scorer - .add_scores(&sentence, self.data.padding, &mut ys); - for (&y, b) in ys[self.data.padding.into()..] - .iter() - .zip(sentence.boundaries.iter_mut()) - { - *b = if y >= 0 { - BoundaryType::WordBoundary - } else { - BoundaryType::NotWordBoundary - }; + } +} + +#[cfg(feature = "tag-prediction")] +#[derive(Debug, Default, Eq, PartialEq)] +pub struct PositionalWeightWithTag { + pub weight: Option>>, + pub tag_info: HashMap<(usize, u8), Vec>, +} + +#[cfg(feature = "tag-prediction")] +impl PositionalWeightWithTag { + pub fn with_boundary(offset: i16, weight: Vec) -> Self { + Self { + weight: Some(PositionalWeight::new(offset, weight)), + tag_info: HashMap::new(), } - sentence.boundary_scores = ys; - sentence } - /// Predicts word boundaries. - /// - /// # Arguments - /// - /// * `sentence` - A sentence. - /// - /// # Returns - /// - /// A sentence with predicted boundary information. - pub fn predict(&self, sentence: Sentence) -> Sentence { - let mut sentence = self.predict_impl(sentence); - sentence.boundary_scores.clear(); - sentence + pub fn with_tag(token_id: usize, rel_position: u8, tag_weight: Vec) -> Self { + let mut tag_info = HashMap::new(); + tag_info.insert((token_id, rel_position), tag_weight); + Self { + weight: None, + tag_info, + } } +} - /// Predicts word boundaries. This function inserts scores. - /// - /// # Arguments - /// - /// * `sentence` - A sentence. - /// - /// # Returns - /// - /// A sentence with predicted boundary information. 
- pub fn predict_with_score(&self, sentence: Sentence) -> Sentence { - let mut sentence = self.predict_impl(sentence); - sentence - .boundary_scores - .rotate_left(self.data.padding.into()); - sentence.boundary_scores.truncate(sentence.boundaries.len()); - sentence +#[cfg(feature = "tag-prediction")] +impl AddAssign<&Self> for PositionalWeightWithTag { + fn add_assign(&mut self, other: &Self) { + if let Some(y) = self.weight.as_mut() { + if let Some(x) = other.weight.as_ref() { + *y += x; + } + } else { + self.weight = other.weight.clone(); + } + for (k, v) in &other.tag_info { + self.tag_info + .entry(*k) + .and_modify(|w| { + for (y, x) in w.iter_mut().zip(v) { + *y += *x; + } + }) + .or_insert_with(|| v.clone()); + } + } +} + +#[cfg(feature = "tag-prediction")] +#[derive(Decode, Encode)] +struct TagPredictor { + tags: Vec>, + bias: WeightVector, +} + +#[cfg(feature = "tag-prediction")] +impl TagPredictor { + pub fn new(tags: Vec>, bias: Vec) -> Self { + Self { + tags, + bias: bias.into(), + } + } + + #[inline] + pub const fn bias(&self) -> &WeightVector { + &self.bias + } + + #[inline] + pub fn predict<'a>(&'a self, scores: &[i32], tags: &mut [Option>]) { + let mut offset = 0; + for (tag_cands, tag) in self.tags.iter().zip(tags) { + if tag_cands.len() >= 2 { + let mut idx = 0; + let mut max_score = i32::MIN; + for (i, &s) in scores[offset..offset + tag_cands.len()].iter().enumerate() { + if s > max_score { + idx = i; + max_score = s; + } + } + tag.replace(Cow::Borrowed(&tag_cands[idx])); + offset += tag_cands.len(); + } else { + *tag = tag_cands.first().map(|t| Cow::Borrowed(t.as_str())); + } + } } +} + +pub struct PredictorData { + char_scorer: Option, + type_scorer: Option, + bias: i32, #[cfg(feature = "tag-prediction")] - fn best_tag(&self, scores: &[i32]) -> Arc { - Arc::clone( - scores - .iter() - .zip(&self.data.tag_names) - .max_by_key(|(&x, _)| x) - .unwrap() - .1, - ) + tag_predictor: Option>, + #[cfg(feature = "tag-prediction")] + n_tags: usize, +} + +impl<'de> BorrowDecode<'de> for PredictorData { + /// WARNING: This function is inherently unsafe. Do not publish this function outside this + /// crate. + fn borrow_decode>(decoder: &mut D) -> Result { + let config = bincode::config::standard(); + let char_scorer_data: Option<&[u8]> = BorrowDecode::borrow_decode(decoder)?; + let char_scorer = if let Some(data) = char_scorer_data { + Some(bincode::decode_from_slice(data, config)?.0) + } else { + None + }; + let type_scorer_data: Option<&[u8]> = BorrowDecode::borrow_decode(decoder)?; + let type_scorer = if let Some(data) = type_scorer_data { + Some(bincode::decode_from_slice(data, config)?.0) + } else { + None + }; + let bias = Decode::decode(decoder)?; + #[cfg(feature = "tag-prediction")] + let tag_predictor = Decode::decode(decoder)?; + #[cfg(feature = "tag-prediction")] + let n_tags = Decode::decode(decoder)?; + Ok(Self { + char_scorer, + type_scorer, + bias, + #[cfg(feature = "tag-prediction")] + tag_predictor, + #[cfg(feature = "tag-prediction")] + n_tags, + }) } +} - /// Fills tags using calculated scores. - /// - /// Tags are predicted using token boundaries, so you have to apply boundary post-processors - /// before filling tags. +impl Encode for PredictorData { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + let config = bincode::config::standard(); + let char_scorer_data = if let Some(char_scorer) = self.char_scorer.as_ref() { + Some(bincode::encode_to_vec(char_scorer, config)?) 
+ } else { + None + }; + Encode::encode(&char_scorer_data, encoder)?; + let type_scorer_data = if let Some(type_scorer) = self.type_scorer.as_ref() { + Some(bincode::encode_to_vec(type_scorer, config)?) + } else { + None + }; + Encode::encode(&type_scorer_data, encoder)?; + Encode::encode(&self.bias, encoder)?; + #[cfg(feature = "tag-prediction")] + Encode::encode(&self.tag_predictor, encoder)?; + #[cfg(feature = "tag-prediction")] + Encode::encode(&self.n_tags, encoder)?; + Ok(()) + } +} + +/// Predictor created from the model. +/// +#[cfg_attr( + feature = "std", + doc = " +# Example 1: without tag prediction + +``` +use std::fs::File; + +use vaporetto::{Model, Predictor, Sentence}; + +let f = File::open(\"../resources/model.bin\").unwrap(); +let model = Model::read(f).unwrap(); +let predictor = Predictor::new(model, false).unwrap(); + +let mut s = Sentence::from_raw(\"まぁ社長は火星猫だ\").unwrap(); +predictor.predict(&mut s); +// s.fill_tags(); will panic! + +let mut buf = String::new(); +s.write_tokenized_text(&mut buf); +assert_eq!( + \"まぁ 社長 は 火星 猫 だ\", + buf, +); +``` +" +)] +#[cfg_attr( + all(feature = "std", feature = "tag-prediction"), + doc = " +# Example 2: with tag prediction + +Tag prediction requires **crate feature** `tag-prediction`. +``` +use std::fs::File; + +use vaporetto::{Model, Predictor, Sentence}; + +let mut f = File::open(\"../resources/model.bin\").unwrap(); +let model = Model::read(f).unwrap(); +let predictor = Predictor::new(model, true).unwrap(); + +let mut s = Sentence::from_raw(\"まぁ社長は火星猫だ\").unwrap(); +predictor.predict(&mut s); +s.fill_tags(); + +let mut buf = String::new(); +s.write_tokenized_text(&mut buf); +assert_eq!( + \"まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ 火星/名詞/カセー 猫/名詞/ネコ だ/助動詞/ダ\", + buf, +); +``` +" +)] +pub struct Predictor(PredictorData); + +impl Predictor { + /// Creates a new predictor from the model. /// /// # Arguments /// - /// * `sentence` - A sentence. + /// * `model` - A model data. + /// * `predict_tags` - If you want to predict tags, set to true. /// - /// # Returns + /// # Errors /// - /// A sentence with tag information. When the predictor is instantiated with - /// `predict_tag = false`, the sentence is returned without any modification. - #[cfg(feature = "tag-prediction")] - #[cfg_attr(docsrs, doc(cfg(feature = "tag-prediction")))] - pub fn fill_tags(&self, mut sentence: Sentence) -> Sentence { - if self.data.tag_names.is_empty() { - return sentence; + /// Returns an error variant when the model is invalid. + pub fn new(model: Model, predict_tags: bool) -> Result { + #[cfg(feature = "tag-prediction")] + let mut tag_char_ngram_model = vec![]; + #[cfg(feature = "tag-prediction")] + let mut tag_type_ngram_model = vec![]; + #[cfg(feature = "tag-prediction")] + let mut n_tags = 0; + + #[cfg(not(feature = "tag-prediction"))] + if predict_tags { + panic!("tag prediction is unsupported"); } - if sentence.tags.is_empty() { - sentence.tags.resize(sentence.chars().len(), None); + #[cfg(feature = "tag-prediction")] + let tag_predictor = predict_tags.then(|| { + let mut tag_predictor = HashMap::new(); + for (i, tag_model) in model.0.tag_models.into_iter().enumerate() { + n_tags = n_tags.max(tag_model.tags.len()); + // token does not duplicate in the model. 
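+                // Each token maps to (token_id, TagPredictor); token_id is also the index of
+                // this token's n-gram models pushed into tag_char_ngram_model and
+                // tag_type_ngram_model below, so the scorers can look up its tag weights by id.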
+ tag_predictor.insert( + tag_model.token, + ( + u32::try_from(i).unwrap(), + TagPredictor::new(tag_model.tags, tag_model.bias), + ), + ); + tag_char_ngram_model.push(tag_model.char_ngram_model); + tag_type_ngram_model.push(tag_model.type_ngram_model); + } + SerializableHashMap(tag_predictor) + }); + + let char_scorer = CharScorer::new( + model.0.char_ngram_model, + model.0.dict_model, + model.0.char_window_size, + #[cfg(feature = "tag-prediction")] + tag_char_ngram_model, + )?; + let type_scorer = TypeScorer::new( + model.0.type_ngram_model, + model.0.type_window_size, + #[cfg(feature = "tag-prediction")] + tag_type_ngram_model, + )?; + Ok(Self(PredictorData { + char_scorer, + type_scorer, + bias: model.0.bias, + + #[cfg(feature = "tag-prediction")] + tag_predictor, + #[cfg(feature = "tag-prediction")] + n_tags, + })) + } + + /// Predicts word boundaries of the given sentence. + /// If necessary, this function also prepares for predicting tags. + pub fn predict<'a, 'b>(&'b self, sentence: &mut Sentence<'a, 'b>) { + sentence.score_padding = WEIGHT_FIXED_LEN - 1; + sentence.boundary_scores.clear(); + sentence + .boundary_scores + .resize(sentence.score_padding * 2 + sentence.len() - 1, self.0.bias); + if let Some(scorer) = self.0.char_scorer.as_ref() { + scorer.add_scores(sentence); } - let n_tags = self.data.tag_names.len(); - let mut tag_score = self.data.tag_bias.clone(); - let mut left_scores_iter = sentence.tag_scores.left_scores.chunks(n_tags); - for (t, l) in tag_score.iter_mut().zip(left_scores_iter.next().unwrap()) { - *t += l; + if let Some(scorer) = self.0.type_scorer.as_ref() { + scorer.add_scores(sentence); } - let mut right_scores_iter = sentence.tag_scores.right_scores.chunks(n_tags); - let mut last_boundary_idx = 0; - for (i, ((((b, left_scores), right_scores), self_scores), tag)) in sentence + for (b, s) in sentence .boundaries - .iter() - .zip(left_scores_iter) - .zip(&mut right_scores_iter) - .zip(&sentence.tag_scores.self_scores) - .zip(&mut sentence.tags) - .enumerate() + .iter_mut() + .zip(&sentence.boundary_scores[sentence.score_padding..]) { - if *b == BoundaryType::WordBoundary { - for (t, r) in tag_score.iter_mut().zip(right_scores) { - *t += *r; - } - if let Some(self_weights) = self_scores.as_ref() { - let diff = -i16::try_from(i + 1 - last_boundary_idx).unwrap_or(0); - for self_weight in self_weights.iter() { - match self_weight.start_rel_position.cmp(&diff) { - Ordering::Greater => continue, - Ordering::Equal => { - for (t, s) in tag_score.iter_mut().zip(&self_weight.weight) { - *t += *s; - } + if *s > 0 { + *b = CharacterBoundary::WordBoundary; + } else { + *b = CharacterBoundary::NotWordBoundary; + } + } + sentence.set_predictor(self); + } + + #[cfg(feature = "tag-prediction")] + pub(crate) fn predict_tags<'a, 'b>(&'b self, sentence: &mut Sentence<'a, 'b>) { + let tag_predictor = self + .0 + .tag_predictor + .as_ref() + .expect("this predictor is created with predict_tags = false"); + + if self.0.n_tags == 0 { + return; + } + let mut scores = vec![]; + let mut range_start = Some(0); + sentence.n_tags = self.0.n_tags; + sentence.tags.clear(); + sentence.tags.resize(sentence.len() * self.0.n_tags, None); + for (i, &b) in sentence.boundaries.iter().enumerate() { + if b == CharacterBoundary::Unknown { + range_start.take(); + } else if b == CharacterBoundary::WordBoundary { + if let Some(&range_start) = range_start.as_ref() { + let token = sentence.text_substring(range_start, i + 1); + if let Some((token_id, tag_predictor)) = tag_predictor.get(token) { + 
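+                        // Score this completed token: start from the tag bias, let the char
+                        // and type scorers add their context n-gram weights at the token's
+                        // last character, then take the argmax per tag slot in predict().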
scores.clear(); + scores.resize(tag_predictor.bias().len(), 0); + tag_predictor.bias().add_scores(&mut scores); + if let Some(scorer) = self.0.char_scorer.as_ref() { + debug_assert!(i < sentence.char_pma_states.len()); + // token_id is always smaller than tag_weight.len() because + // tag_predictor is created to contain such values in the new() + // function. + unsafe { + scorer.add_tag_scores(*token_id, i, sentence, &mut scores); } - Ordering::Less => (), } - break; + if let Some(scorer) = self.0.type_scorer.as_ref() { + debug_assert!(i < sentence.type_pma_states.len()); + // token_id is always smaller than tag_weight.len() because + // tag_predictor is created to contain such values in the new() + // function. + unsafe { + scorer.add_tag_scores(*token_id, i, sentence, &mut scores); + } + } + tag_predictor.predict( + &scores, + &mut sentence.tags[i * self.0.n_tags..(i + 1) * self.0.n_tags], + ); } } - tag.replace(self.best_tag(&tag_score)); - for (t, (l, b)) in tag_score - .iter_mut() - .zip(left_scores.iter().zip(&self.data.tag_bias)) - { - *t = *l + *b; - } - last_boundary_idx = i + 1; + range_start.replace(i + 1); } } - for (t, r) in tag_score.iter_mut().zip(right_scores_iter.next().unwrap()) { - *t += r; - } - if let Some(self_weights) = sentence.tag_scores.self_scores.last().unwrap().as_ref() { - let diff = -i16::try_from(sentence.chars.len() - last_boundary_idx).unwrap_or(0); - for self_weight in self_weights.iter() { - match self_weight.start_rel_position.cmp(&diff) { - Ordering::Greater => continue, - Ordering::Equal => { - for (t, s) in tag_score.iter_mut().zip(&self_weight.weight) { - *t += *s; - } + if let Some(&range_start) = range_start.as_ref() { + let token = sentence.text_substring(range_start, sentence.len()); + if let Some((token_id, tag_predictor)) = tag_predictor.get(token) { + scores.clear(); + scores.resize(tag_predictor.bias().len(), 0); + tag_predictor.bias().add_scores(&mut scores); + if let Some(scorer) = self.0.char_scorer.as_ref() { + debug_assert!(sentence.len() <= sentence.char_pma_states.len()); + // token_id is always smaller than tag_weight.len() because tag_predictor is + // created to contain such values in the new() function. + unsafe { + scorer.add_tag_scores(*token_id, sentence.len() - 1, sentence, &mut scores); + } + } + if let Some(scorer) = self.0.type_scorer.as_ref() { + debug_assert!(sentence.len() <= sentence.type_pma_states.len()); + // token_id is always smaller than tag_weight.len() because tag_predictor is + // created to contain such values in the new() function. + unsafe { + scorer.add_tag_scores(*token_id, sentence.len() - 1, sentence, &mut scores); } - Ordering::Less => (), } - break; + let i = sentence.len() - 1; + tag_predictor.predict(&scores, &mut sentence.tags[i * self.0.n_tags..]); } } - sentence - .tags - .last_mut() - .unwrap() - .replace(self.best_tag(&tag_score)); + } - sentence + /// Serializes the predictor into a Vec. + pub fn serialize_to_vec(&self) -> Result> { + let config = bincode::config::standard(); + let result = bincode::encode_to_vec(&self.0, config)?; + Ok(result) + } + + /// Deserializes a predictor from a given slice and returns a tuple of the predictor and the remaining slice. + /// + /// # Safety + /// + /// The given data must be a correct predictor exported by [`Predictor::serialize_to_vec()`] + /// function. 
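+    ///
+    /// # Examples
+    ///
+    /// A minimal round-trip sketch, mirroring the `test_serialization` test below (the model
+    /// path is the same resource used in the [`Predictor`] examples):
+    ///
+    /// ```
+    /// use std::fs::File;
+    ///
+    /// use vaporetto::{Model, Predictor, Sentence};
+    ///
+    /// let f = File::open("../resources/model.bin").unwrap();
+    /// let model = Model::read(f).unwrap();
+    /// let predictor = Predictor::new(model, false).unwrap();
+    ///
+    /// let data = predictor.serialize_to_vec().unwrap();
+    ///
+    /// // Safe only because `data` was just produced by `serialize_to_vec()`.
+    /// let (predictor, _rest) = unsafe {
+    ///     Predictor::deserialize_from_slice_unchecked(&data).unwrap()
+    /// };
+    ///
+    /// let mut s = Sentence::from_raw("まぁ社長は火星猫だ").unwrap();
+    /// predictor.predict(&mut s);
+    /// ```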
+ pub unsafe fn deserialize_from_slice_unchecked(data: &[u8]) -> Result<(Self, &[u8])> { + let config = bincode::config::standard(); + // Deserialization is unsafe because the automaton will not be verified. + let (predictor_data, size) = bincode::decode_from_slice(data, config)?; + Ok((Self(predictor_data), &data[size..])) } } @@ -336,801 +626,288 @@ impl Predictor { mod tests { use super::*; - use alloc::string::ToString; - use crate::dict_model::{DictModel, WordWeightRecord}; - use crate::ngram_model::{NgramData, NgramModel}; - use crate::sentence::CharacterType::*; - use crate::tag_model::TagModel; + use crate::model::TagModel; + use crate::ngram_model::{NgramData, NgramModel, TagNgramData, TagNgramModel, TagWeight}; + use crate::CharacterBoundary::*; + use crate::CharacterType::*; - #[cfg(feature = "tag-prediction")] - use crate::sentence::Token; - #[cfg(feature = "tag-prediction")] - use crate::tag_model::TagClassInfo; - - /// Input: 我 ら は 全 世 界 の 国 民 - /// bias: -200 .. .. .. .. .. .. .. - /// words: - /// 我ら: 3 4 5 - /// 全世界: 6 7 8 9 - /// 国民: 10 11 12 - /// 世界: 15 16 17 18 19 - /// 界: 20 21 22 23 24 25 - /// types: - /// H: 27 28 29 - /// 26 27 28 29 - /// 26 27 28 29 - /// K: 32 33 - /// 30 31 32 33 - /// 30 31 32 33 - /// 30 31 32 33 - /// 30 31 32 - /// 30 31 - /// KH: 35 36 - /// 34 35 36 - /// HK: 37 38 39 - /// 37 38 39 - /// dict: - /// 全世界: 43 44 44 45 - /// 世界: 43 44 45 - /// 世: 40 42 - fn generate_model_1() -> Model { - Model::new( - NgramModel { - data: vec![ - NgramData { - ngram: "我ら".to_string(), - weights: vec![1, 2, 3, 4, 5], - }, - NgramData { - ngram: "全世界".to_string(), - weights: vec![6, 7, 8, 9], - }, - NgramData { - ngram: "国民".to_string(), - weights: vec![10, 11, 12, 13, 14], - }, - NgramData { - ngram: "世界".to_string(), - weights: vec![15, 16, 17, 18, 19], - }, - NgramData { - ngram: "界".to_string(), - weights: vec![20, 21, 22, 23, 24, 25], - }, - ], - }, - NgramModel { - data: vec![ - NgramData { - ngram: vec![Hiragana as u8], - weights: vec![26, 27, 28, 29], - }, - NgramData { - ngram: vec![Kanji as u8], - weights: vec![30, 31, 32, 33], - }, - NgramData { - ngram: vec![Kanji as u8, Hiragana as u8], - weights: vec![34, 35, 36], - }, - NgramData { - ngram: vec![Hiragana as u8, Kanji as u8], - weights: vec![37, 38, 39], - }, - ], - }, - DictModel { - dict: vec![ - WordWeightRecord { - word: "全世界".to_string(), - weights: vec![43, 44, 44, 45], - comment: "".to_string(), - }, - WordWeightRecord { - word: "世界".to_string(), - weights: vec![43, 44, 45], - comment: "".to_string(), - }, - WordWeightRecord { - word: "世".to_string(), - weights: vec![40, 42], - comment: "".to_string(), - }, - ], - }, - -200, - 3, - 2, - TagModel::default(), - ) + #[test] + fn test_positional_weight_add_assign_1() { + let mut y = PositionalWeight::new(-2, vec![1, 2, 3, 4]); + let x = PositionalWeight::new(4, vec![2, 4, 8]); + y += &x; + assert_eq!(-2, y.offset); + assert_eq!(vec![1, 2, 3, 4, 0, 0, 2, 4, 8], y.weight); } - /// Input: 我 ら は 全 世 界 の 国 民 - /// bias: -285 .. .. .. .. .. .. .. 
- /// words: - /// 我ら: 2 3 - /// 全世界: 4 5 - /// 国民: 6 7 - /// 世界: 9 10 11 - /// 界: 12 13 14 15 - /// types: - /// H: 18 19 20 21 - /// 17 18 19 20 21 - /// 16 17 18 19 20 - /// K: 25 26 27 - /// 22 23 24 25 26 27 - /// 22 23 24 25 26 27 - /// 22 23 24 25 26 27 - /// 22 23 24 25 - /// 22 23 24 - /// KH: 30 31 32 - /// 28 29 30 31 32 - /// HK: 33 34 35 36 37 - /// 33 34 35 36 - /// dict: - /// 全世界: 44 45 45 46 - /// 世界: 41 42 43 - /// 世: 38 40 - fn generate_model_2() -> Model { - Model::new( - NgramModel { - data: vec![ - NgramData { - ngram: "我ら".to_string(), - weights: vec![1, 2, 3], - }, - NgramData { - ngram: "全世界".to_string(), - weights: vec![4, 5], - }, - NgramData { - ngram: "国民".to_string(), - weights: vec![6, 7, 8], - }, - NgramData { - ngram: "世界".to_string(), - weights: vec![9, 10, 11], - }, - NgramData { - ngram: "界".to_string(), - weights: vec![12, 13, 14, 15], - }, - ], - }, - NgramModel { - data: vec![ - NgramData { - ngram: vec![Hiragana as u8], - weights: vec![16, 17, 18, 19, 20, 21], - }, - NgramData { - ngram: vec![Kanji as u8], - weights: vec![22, 23, 24, 25, 26, 27], - }, - NgramData { - ngram: vec![Kanji as u8, Hiragana as u8], - weights: vec![28, 29, 30, 31, 32], - }, - NgramData { - ngram: vec![Hiragana as u8, Kanji as u8], - weights: vec![33, 34, 35, 36, 37], - }, - ], - }, - DictModel { - dict: vec![ - WordWeightRecord { - word: "全世界".to_string(), - weights: vec![44, 45, 45, 46], - comment: "".to_string(), - }, - WordWeightRecord { - word: "世界".to_string(), - weights: vec![41, 42, 43], - comment: "".to_string(), - }, - WordWeightRecord { - word: "世".to_string(), - weights: vec![38, 40], - comment: "".to_string(), - }, - ], - }, - -285, - 2, - 3, - TagModel::default(), - ) + #[test] + fn test_positional_weight_add_assign_2() { + let mut y = PositionalWeight::new(-2, vec![1, 2, 3, 4]); + let x = PositionalWeight::new(2, vec![2, 4, 8]); + y += &x; + assert_eq!(-2, y.offset); + assert_eq!(vec![1, 2, 3, 4, 2, 4, 8], y.weight); } - /// Input: 我 ら は 全 世 界 の 国 民 - /// bias: -285 .. .. .. .. .. .. .. 
- /// words: - /// 我ら: 2 3 - /// 全世界: 4 5 - /// 国民: 6 7 - /// 世界: 9 10 11 - /// 界: 12 13 14 15 - /// types: - /// H: 18 19 20 21 - /// 17 18 19 20 21 - /// 16 17 18 19 20 - /// K: 25 26 27 - /// 22 23 24 25 26 27 - /// 22 23 24 25 26 27 - /// 22 23 24 25 26 27 - /// 22 23 24 25 - /// 22 23 24 - /// KH: 30 31 32 - /// 28 29 30 31 32 - /// HK: 33 34 35 36 37 - /// 33 34 35 36 - /// dict: - /// 国民: 38 39 - /// 世界: 41 42 43 - /// 世: 44 46 - fn generate_model_3() -> Model { - Model::new( - NgramModel { - data: vec![ - NgramData { - ngram: "我ら".to_string(), - weights: vec![1, 2, 3], - }, - NgramData { - ngram: "全世界".to_string(), - weights: vec![4, 5], - }, - NgramData { - ngram: "国民".to_string(), - weights: vec![6, 7, 8], - }, - NgramData { - ngram: "世界".to_string(), - weights: vec![9, 10, 11], - }, - NgramData { - ngram: "界".to_string(), - weights: vec![12, 13, 14, 15], - }, - ], - }, - NgramModel { - data: vec![ - NgramData { - ngram: vec![Hiragana as u8], - weights: vec![16, 17, 18, 19, 20, 21], - }, - NgramData { - ngram: vec![Kanji as u8], - weights: vec![22, 23, 24, 25, 26, 27], - }, - NgramData { - ngram: vec![Kanji as u8, Hiragana as u8], - weights: vec![28, 29, 30, 31, 32], - }, - NgramData { - ngram: vec![Hiragana as u8, Kanji as u8], - weights: vec![33, 34, 35, 36, 37], - }, - ], - }, - DictModel { - dict: vec![ - WordWeightRecord { - word: "国民".to_string(), - weights: vec![38, 39, 40], - comment: "".to_string(), - }, - WordWeightRecord { - word: "世界".to_string(), - weights: vec![41, 42, 43], - comment: "".to_string(), - }, - WordWeightRecord { - word: "世".to_string(), - weights: vec![44, 46], - comment: "".to_string(), - }, - ], - }, - -285, - 2, - 3, - TagModel::default(), - ) + #[test] + fn test_positional_weight_add_assign_3() { + let mut y = PositionalWeight::new(-2, vec![1, 2, 3, 4]); + let x = PositionalWeight::new(0, vec![2, 4, 8]); + y += &x; + assert_eq!(-2, y.offset); + assert_eq!(vec![1, 2, 5, 8, 8], y.weight); } - /// Input: 我 ら は 全 世 界 の 国 民 - /// bias: -200 .. .. .. .. .. .. .. 
- /// chars: - /// 我ら: 3 4 5 - /// 全世界: 6 7 8 9 - /// 国民: 10 11 12 - /// 世界: 15 16 17 18 19 - /// 界: 20 21 22 23 24 25 - /// types: - /// H: 27 28 29 - /// 26 27 28 29 - /// 26 27 28 29 - /// K: 32 33 - /// 30 31 32 33 - /// 30 31 32 33 - /// 30 31 32 33 - /// 30 31 32 - /// 30 31 - /// KH: 35 36 - /// 34 35 36 - /// HK: 37 38 39 - /// 37 38 39 - /// dict: - /// 全世界: 43 44 44 45 - /// 世界: 43 44 45 - /// 世: 40 42 - /// 世界の国民: 43 44 44 44 44 - /// は全世界: 43 44 44 44 45 - /// - /// - /// は全世界: 43 44 44 44 45 - /// 15 16 17 18 19 - /// 20 21 22 23 24 25 - /// 6 7 8 9 - fn generate_model_4() -> Model { - Model::new( - NgramModel { - data: vec![ - NgramData { - ngram: "我ら".to_string(), - weights: vec![1, 2, 3, 4, 5], - }, - NgramData { - ngram: "全世界".to_string(), - weights: vec![6, 7, 8, 9], - }, - NgramData { - ngram: "国民".to_string(), - weights: vec![10, 11, 12, 13, 14], - }, - NgramData { - ngram: "世界".to_string(), - weights: vec![15, 16, 17, 18, 19], - }, - NgramData { - ngram: "界".to_string(), - weights: vec![20, 21, 22, 23, 24, 25], - }, - ], - }, - NgramModel { - data: vec![ - NgramData { - ngram: vec![Hiragana as u8], - weights: vec![26, 27, 28, 29], - }, - NgramData { - ngram: vec![Kanji as u8], - weights: vec![30, 31, 32, 33], - }, - NgramData { - ngram: vec![Kanji as u8, Hiragana as u8], - weights: vec![34, 35, 36], - }, - NgramData { - ngram: vec![Hiragana as u8, Kanji as u8], - weights: vec![37, 38, 39], - }, - ], - }, - DictModel { - dict: vec![ - WordWeightRecord { - word: "全世界".to_string(), - weights: vec![43, 44, 44, 45], - comment: "".to_string(), - }, - WordWeightRecord { - word: "世界".to_string(), - weights: vec![43, 44, 45], - comment: "".to_string(), - }, - WordWeightRecord { - word: "世".to_string(), - weights: vec![40, 42], - comment: "".to_string(), - }, - WordWeightRecord { - word: "世界の国民".to_string(), - weights: vec![43, 44, 44, 44, 44, 45], - comment: "".to_string(), - }, - WordWeightRecord { - word: "は全世界".to_string(), - weights: vec![43, 44, 44, 44, 45], - comment: "".to_string(), - }, - ], - }, - -200, - 3, - 2, - TagModel::default(), - ) + #[test] + fn test_positional_weight_add_assign_4() { + let mut y = PositionalWeight::new(-2, vec![1, 2, 3, 4]); + let x = PositionalWeight::new(-1, vec![2, 4, 8]); + y += &x; + assert_eq!(-2, y.offset); + assert_eq!(vec![1, 4, 7, 12], y.weight); } - /// Input: 人 と 人 を つ な ぐ 人 - /// left: - /// \0人: 1 4 - /// 2 5 - /// 3 6 - /// 人: 7 10 7 10 - /// 8 11 8 11 - /// 9 12 9 12 - /// つなぐ: 13 16 19 - /// 14 17 20 - /// 15 18 21 - /// 人\0: 22 - /// 23 - /// 24 - /// - /// sum: 1 11 10 7 10 13 16 41 - /// 2 13 11 8 11 14 17 43 - /// 3 15 12 9 12 15 18 45 - /// - /// right: - /// \0人と: 28 - /// 29 - /// 30 - /// 人を: 31 34 37 - /// 32 35 38 - /// 33 36 39 - /// を: 40 43 - /// 41 44 - /// 42 45 - /// 人\0: 46 49 - /// 47 50 - /// 48 51 - /// - /// sum: 28 71 77 37 0 0 46 49 - /// 29 73 79 38 0 0 47 50 - /// 30 75 81 39 0 0 48 51 - #[cfg(feature = "tag-prediction")] - fn generate_model_5() -> Model { - Model::new( - NgramModel { - data: vec![NgramData { - ngram: "xxxx".to_string(), - weights: vec![0], - }], - }, - NgramModel { - data: vec![NgramData { - ngram: vec![Roman as u8, Roman as u8, Roman as u8, Roman as u8], - weights: vec![0], - }], - }, - DictModel { dict: vec![] }, - 0, - 2, - 2, - TagModel { - class_info: vec![ - TagClassInfo { - name: "名詞".to_string(), - bias: 5, - }, - TagClassInfo { - name: "動詞".to_string(), - bias: 3, - }, - TagClassInfo { - name: "助詞".to_string(), - bias: 1, - }, - ], - left_char_model: NgramModel { - data: 
vec![ - NgramData { - ngram: "\0人".to_string(), - weights: vec![1, 2, 3, 4, 5, 6], - }, - NgramData { - ngram: "人".to_string(), - weights: vec![7, 8, 9, 10, 11, 12], - }, - NgramData { - ngram: "つなぐ".to_string(), - weights: vec![13, 14, 15, 16, 17, 18, 19, 20, 21], - }, - NgramData { - ngram: "ぐ人\0".to_string(), - weights: vec![22, 23, 24], - }, - ], - }, - right_char_model: NgramModel { - data: vec![ - NgramData { - ngram: "\0人と".to_string(), - weights: vec![25, 26, 27, 28, 29, 30], - }, - NgramData { - ngram: "人を".to_string(), - weights: vec![31, 32, 33, 34, 35, 36, 37, 38, 39], - }, - NgramData { - ngram: "を".to_string(), - weights: vec![40, 41, 42, 43, 44, 45], - }, - NgramData { - ngram: "人\0".to_string(), - weights: vec![46, 47, 48, 49, 50, 51], - }, - ], - }, - self_char_model: NgramModel { - data: vec![ - NgramData { - ngram: "人".to_string(), - weights: vec![2, -1, -1], - }, - NgramData { - ngram: "と".to_string(), - weights: vec![0, 0, 0], - }, - NgramData { - ngram: "つなぐ".to_string(), - weights: vec![0, 1, 0], - }, - NgramData { - ngram: "を".to_string(), - weights: vec![0, 0, 0], - }, - ], - }, - }, - ) + #[test] + fn test_positional_weight_add_assign_5() { + let mut y = PositionalWeight::new(-2, vec![1, 2, 3, 4]); + let x = PositionalWeight::new(-2, vec![2, 4, 8]); + y += &x; + assert_eq!(-2, y.offset); + assert_eq!(vec![3, 6, 11, 4], y.weight); } #[test] - fn test_predict_1() { - let model = generate_model_1(); - let p = Predictor::new(model, false).unwrap(); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict(s); - assert_eq!( - &[ - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - ], - s.boundaries(), - ); + fn test_positional_weight_add_assign_6() { + let mut y = PositionalWeight::new(-2, vec![1, 2, 3, 4]); + let x = PositionalWeight::new(-4, vec![2, 4, 8]); + y += &x; + assert_eq!(-4, y.offset); + assert_eq!(vec![2, 4, 9, 2, 3, 4], y.weight); } #[test] - fn test_predict_2() { - let model = generate_model_2(); - let p = Predictor::new(model, false).unwrap(); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict(s); - assert_eq!( - &[ - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - ], - s.boundaries(), - ); + fn test_positional_weight_add_assign_7() { + let mut y = PositionalWeight::new(-2, vec![1, 2, 3, 4]); + let x = PositionalWeight::new(-5, vec![2, 4, 8]); + y += &x; + assert_eq!(-5, y.offset); + assert_eq!(vec![2, 4, 8, 1, 2, 3, 4], y.weight); } #[test] - fn test_predict_3() { - let model = generate_model_3(); - let p = Predictor::new(model, false).unwrap(); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict(s); - assert_eq!( - &[ - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - ], - s.boundaries(), - ); + fn test_positional_weight_add_assign_8() { + let mut y = PositionalWeight::new(-2, vec![1, 2, 3, 4]); + let x = PositionalWeight::new(-7, vec![2, 4, 8]); + y += &x; + assert_eq!(-7, y.offset); + 
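+        // `x` occupies offsets -7..=-5 and `y` starts at offset -2, so the uncovered
+        // offsets -4 and -3 are zero-filled in the merged weight vector.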
assert_eq!(vec![2, 4, 8, 0, 0, 1, 2, 3, 4], y.weight); } - #[test] - fn test_predict_4() { - let model = generate_model_4(); - let p = Predictor::new(model, false).unwrap(); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict(s); - assert_eq!( - &[ - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, + fn create_test_model() -> Model { + // input: こ の 人 は 地 球 人 だ + // n-grams: + // この人: -2 3 4 + // 人だ: -5 6 7 + // n-grams: + // HHK: -11 12 13 + // KH: -14 15 16 17 -18 + // -14 15 16 + // dict: + // 人: 19 20 19 20 + // 地球: 21 -22 23 + Model::new( + NgramModel(vec![ + NgramData { + ngram: "この人".into(), + weights: vec![1, -2, 3, 4], + }, + NgramData { + ngram: "人だ".into(), + weights: vec![-5, 6, 7, 8, 9], + }, + ]), + NgramModel(vec![ + NgramData { + ngram: vec![Hiragana as u8, Hiragana as u8, Kanji as u8], + weights: vec![10, -11, 12, 13], + }, + NgramData { + ngram: vec![Kanji as u8, Hiragana as u8], + weights: vec![-14, 15, 16, 17, -18], + }, + ]), + DictModel(vec![ + WordWeightRecord { + word: "人".into(), + weights: vec![19, 20], + comment: "".into(), + }, + WordWeightRecord { + word: "地球".into(), + weights: vec![21, -22, 23], + comment: "".into(), + }, + ]), + 5, + 3, + 3, + vec![ + TagModel { + token: "人".into(), + tags: vec![ + vec!["名詞".into(), "接尾辞".into()], + vec!["ジン".into(), "ヒト".into()], + ], + char_ngram_model: TagNgramModel(vec![TagNgramData { + ngram: "は地球人".into(), + weights: vec![TagWeight { + rel_position: 0, + weights: vec![-32, 33, 34, -35], + }], + }]), + type_ngram_model: TagNgramModel(vec![TagNgramData { + ngram: vec![Hiragana as u8, Kanji as u8, Hiragana as u8], + weights: vec![TagWeight { + rel_position: 1, + weights: vec![36, -37, -38, 39], + }], + }]), + bias: vec![40, 41, 42, 43], + }, + TagModel { + token: "地球".into(), + tags: vec![ + vec!["名詞".into()], + vec!["マンホーム".into(), "チキュー".into()], + ], + char_ngram_model: TagNgramModel(vec![TagNgramData { + ngram: "は地球人".into(), + weights: vec![TagWeight { + rel_position: 1, + weights: vec![-44, 45], + }], + }]), + type_ngram_model: TagNgramModel(vec![]), + bias: vec![46, 47], + }, ], - s.boundaries(), - ); + ) } #[test] - fn test_predict_with_score_1() { - let model = generate_model_1(); - let p = Predictor::new(model, false).unwrap(); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_with_score(s); - assert_eq!(&[-77, -5, 45, 132, 133, 144, 50, -32], s.boundary_scores(),); + fn test_predict_boundaries() { + let model = create_test_model(); + let predictor = Predictor::new(model, false).unwrap(); + let mut sentence = Sentence::from_raw("この人は地球人だ").unwrap(); + predictor.predict(&mut sentence); + assert_eq!(&[-22, 54, 58, 43, -54, 68, 48], sentence.boundary_scores(),); assert_eq!( &[ - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary ], - s.boundaries(), + sentence.boundaries(), ); } + #[cfg(feature = "tag-prediction")] #[test] - fn test_predict_with_score_2() { - let model = generate_model_2(); - let p = Predictor::new(model, false).unwrap(); - let s = 
Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_with_score(s); - assert_eq!( - &[-138, -109, -39, 57, 104, 34, -79, -114], - s.boundary_scores(), - ); + fn test_predict_tags() { + let model = create_test_model(); + let predictor = Predictor::new(model, true).unwrap(); + let mut sentence = Sentence::from_raw("この人は地球人だ").unwrap(); + predictor.predict(&mut sentence); + sentence.fill_tags(); + assert_eq!(&[-22, 54, 58, 43, -54, 68, 48], sentence.boundary_scores(),); assert_eq!( &[ - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, ], - s.boundaries(), - ); - } - - #[test] - fn test_predict_with_score_3() { - let model = generate_model_3(); - let p = Predictor::new(model, false).unwrap(); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_with_score(s); - assert_eq!( - &[-138, -109, -83, 18, 65, -12, -41, -75], - s.boundary_scores(), + sentence.boundaries(), ); assert_eq!( &[ - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, + None, + None, + None, + None, + Some(Cow::Borrowed("名詞")), + Some(Cow::Borrowed("ヒト")), + None, + None, + None, + None, + Some(Cow::Borrowed("名詞")), + Some(Cow::Borrowed("チキュー")), + Some(Cow::Borrowed("接尾辞")), + Some(Cow::Borrowed("ジン")), + None, + None, ], - s.boundaries(), + sentence.tags() ); } #[test] - fn test_predict_with_score_4() { - let model = generate_model_4(); - let p = Predictor::new(model, false).unwrap(); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_with_score(s); - assert_eq!(&[-77, 38, 89, 219, 221, 233, 94, 12], s.boundary_scores(),); + fn test_serialization() { + let model = create_test_model(); + let predictor = Predictor::new(model, false).unwrap(); + let data = predictor.serialize_to_vec().unwrap(); + let (predictor, _) = unsafe { Predictor::deserialize_from_slice_unchecked(&data).unwrap() }; + let mut sentence = Sentence::from_raw("この人は地球人だ").unwrap(); + predictor.predict(&mut sentence); + assert_eq!(&[-22, 54, 58, 43, -54, 68, 48], sentence.boundary_scores(),); assert_eq!( &[ - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary ], - s.boundaries(), + sentence.boundaries(), ); } #[cfg(feature = "tag-prediction")] #[test] - fn test_predict_with_score_5() { - let model = generate_model_5(); - let p = Predictor::new(model, true).unwrap(); - let s = Sentence::from_raw("人と人をつなぐ人").unwrap(); - let mut s = p.predict(s); - assert_eq!( - &[ - 1, 2, 3, 11, 13, 15, 10, 11, 12, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 41, - 43, 45 - ], - s.tag_scores.left_scores.as_slice() - ); - assert_eq!( - &[ - 28, 29, 30, 71, 73, 75, 77, 79, 81, 37, 38, 39, 0, 0, 0, 0, 0, 0, 46, 47, 48, 49, - 50, 51 - ], - s.tag_scores.right_scores.as_slice() - ); - - 
s.boundaries_mut().copy_from_slice(&[ - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - ]); - let s = p.fill_tags(s); + #[should_panic] + fn test_fill_tags_unsupported() { + let model = create_test_model(); + let predictor = Predictor::new(model, false).unwrap(); + let mut sentence = Sentence::from_raw("この人は地球人だ").unwrap(); + predictor.predict(&mut sentence); + sentence.fill_tags(); + } - assert_eq!( - vec![ - Token { - surface: "人", - tag: Some("名詞") - }, - Token { - surface: "と", - tag: Some("助詞") - }, - Token { - surface: "人", - tag: Some("名詞") - }, - Token { - surface: "を", - tag: Some("助詞") - }, - Token { - surface: "つなぐ", - tag: Some("動詞") - }, - Token { - surface: "人", - tag: Some("名詞") - } - ], - s.to_tokenized_vec().unwrap(), - ); + #[cfg(feature = "tag-prediction")] + #[test] + #[should_panic] + fn test_fill_tags_unsupported_overwrite_prediction() { + let mut sentence = Sentence::from_raw("この人は地球人だ").unwrap(); + + let model = create_test_model(); + let predictor = Predictor::new(model, true).unwrap(); + predictor.predict(&mut sentence); + sentence.fill_tags(); + + let model = create_test_model(); + let predictor = Predictor::new(model, false).unwrap(); + predictor.predict(&mut sentence); + sentence.fill_tags(); } } diff --git a/vaporetto/src/sentence.rs b/vaporetto/src/sentence.rs index ccf961ac..949f159a 100644 --- a/vaporetto/src/sentence.rs +++ b/vaporetto/src/sentence.rs @@ -1,10 +1,9 @@ -use alloc::string::{String, ToString}; -use alloc::sync::Arc; +use alloc::borrow::Cow; +use alloc::string::String; use alloc::vec::Vec; -use bincode::{Decode, Encode}; - use crate::errors::{Result, VaporettoError}; +use crate::predictor::Predictor; /// Character type. #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] @@ -69,9 +68,9 @@ impl CharacterType { } /// Boundary type. -#[derive(Debug, PartialEq, Clone, Copy)] +#[derive(Debug, Eq, PartialEq, Clone, Copy)] #[repr(u8)] -pub enum BoundaryType { +pub enum CharacterBoundary { /// Inner of a word. NotWordBoundary = 0, @@ -82,164 +81,208 @@ pub enum BoundaryType { Unknown = 2, } -/// Token information. -#[derive(Debug, PartialEq, Clone)] -pub struct Token<'a> { - /// A surface of this token. - pub surface: &'a str, - - /// A part-of-speech tag of this token. - pub tag: Option<&'a str>, -} - -/// Weight array with the corresponding range. -/// -/// This data is placed on the end of each range. -#[derive(Debug, PartialEq, Clone, Decode, Encode)] -pub struct TagRangeScore { - /// Weight array. - pub weight: Vec, - - /// The relative position of the start position from the end position. - pub start_rel_position: i16, -} - -impl TagRangeScore { - #[allow(dead_code)] - pub fn new(start_rel_position: i16, weight: Vec) -> Self { - Self { - start_rel_position, - weight, - } - } -} - -pub type TagRangeScores = Arc>; - -#[derive(Debug, PartialEq, Clone, Default)] -pub struct TagScores { - pub left_scores: Vec, - pub right_scores: Vec, - pub self_scores: Vec>, +/// Sentence data containing boundary and tag annotations. 
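+///
+/// The lifetime `'a` is the lifetime of the underlying text (borrowed or owned via `Cow`),
+/// and `'b` is the lifetime of the [`Predictor`] reference that [`Predictor::predict()`]
+/// stores in the sentence so that tags can be filled afterwards.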
+pub struct Sentence<'a, 'b> { + pub(crate) text: Cow<'a, str>, + pub(crate) char_types: Vec, + pub(crate) boundaries: Vec, + pub(crate) boundary_scores: Vec, + pub(crate) score_padding: usize, + pub(crate) char_pma_states: Vec, + pub(crate) type_pma_states: Vec, + pub(crate) tags: Vec>>, + pub(crate) n_tags: usize, + predictor: Option<&'b Predictor>, + str_to_char_pos: Vec, + char_to_str_pos: Vec, } -impl TagScores { - /// Clears scores. - pub fn clear(&mut self) { - self.left_scores.clear(); - self.right_scores.clear(); - self.self_scores.clear(); - } - - /// Initializes score arrays. +impl<'a, 'b> Default for Sentence<'a, 'b> { + /// Creates a new [`Sentence`] consisting of a space. /// - /// # Arguments + /// # Examples /// - /// * `n_chars` - Length of characters in code points. - /// * `n_tags` - The number of tags. - #[allow(dead_code)] - pub fn init(&mut self, n_chars: usize, n_tags: usize) { - self.clear(); - self.left_scores.resize(n_chars * n_tags, 0); - self.right_scores.resize(n_chars * n_tags, 0); - self.self_scores.resize(n_chars, None); - } -} - -/// Sentence with boundary annotations. -#[derive(Debug, PartialEq, Clone)] -pub struct Sentence { - pub(crate) text: String, - pub(crate) chars: Vec, - pub(crate) str_to_char_pos: Vec, - pub(crate) char_to_str_pos: Vec, - pub(crate) char_type: Vec, - pub(crate) boundaries: Vec, - pub(crate) boundary_scores: Vec, - pub(crate) tag_scores: TagScores, - pub(crate) tags: Vec>>, -} - -impl Sentence { - fn internal_new( - text: String, - chars: Vec, - boundaries: Vec, - tags: Vec>>, - ) -> Self { + /// ``` + /// use vaporetto::Sentence; + /// + /// let s = Sentence::default(); + /// + /// assert_eq!(" ", s.as_raw_text()); + /// assert_eq!(0, s.n_tags()); + /// ``` + fn default() -> Self { let mut s = Self { - text, - chars, + text: Cow::Borrowed(""), + char_types: vec![], + boundaries: vec![], + boundary_scores: vec![], + score_padding: 0, + char_pma_states: vec![], + type_pma_states: vec![], + tags: vec![], + n_tags: 0, + predictor: None, str_to_char_pos: vec![], char_to_str_pos: vec![], - char_type: vec![], - boundaries, - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags, }; - s.update_common_info(); + s.set_default(); s } +} - fn clear(&mut self) { - self.text.clear(); - self.text.push(' '); - self.chars.clear(); - self.chars.push(' '); +impl<'a, 'b> Sentence<'a, 'b> { + #[inline(always)] + fn set_default(&mut self) { + self.text = Cow::Borrowed(" "); + self.char_types.clear(); + self.char_types.push(CharacterType::Other as u8); + self.boundaries.clear(); + self.boundary_scores.clear(); + self.score_padding = 0; + self.char_pma_states.clear(); + self.type_pma_states.clear(); + self.tags.clear(); + self.n_tags = 0; + self.predictor.take(); self.str_to_char_pos.clear(); self.str_to_char_pos.push(0); self.str_to_char_pos.push(1); self.char_to_str_pos.clear(); self.char_to_str_pos.push(0); self.char_to_str_pos.push(1); - self.char_type.clear(); - self.char_type.push(CharacterType::Other as u8); - self.boundaries.clear(); - self.boundary_scores.clear(); - self.tag_scores.clear(); - self.tags.clear(); - self.tags.push(None); } - fn parse_raw_text( - raw_text: &str, - chars: &mut Vec, - boundaries: &mut Vec, - tags: &mut Vec>>, + fn parse_raw( + text: &str, + char_types: &mut Vec, + boundaries: &mut Vec, + str_to_char_pos: &mut Vec, + char_to_str_pos: &mut Vec, ) -> Result<()> { - if raw_text.is_empty() { - return Err(VaporettoError::invalid_argument( - "raw_text", - "must contain at least one character", - )); - } - - 
chars.clear();
-
-        for c in raw_text.chars() {
+        char_types.clear();
+        boundaries.clear();
+        str_to_char_pos.clear();
+        char_to_str_pos.clear();
+        char_to_str_pos.push(0);
+        let mut pos = 0;
+        for c in text.chars() {
             if c == '\0' {
                 return Err(VaporettoError::invalid_argument(
-                    "raw_text",
+                    "text",
                     "must not contain NULL",
                 ));
             }
-            chars.push(c);
+            char_types.push(CharacterType::get_type(c) as u8);
+            pos += c.len_utf8();
+            char_to_str_pos.push(pos);
         }
-        boundaries.clear();
-        boundaries.resize(chars.len() - 1, BoundaryType::Unknown);
-        tags.clear();
-        tags.resize(chars.len(), None);
+        if char_types.is_empty() {
+            return Err(VaporettoError::invalid_argument(
+                "text",
+                "must contain at least one character",
+            ));
+        }
+        str_to_char_pos.resize(pos + 1, 0);
+        for (i, &pos) in char_to_str_pos.iter().enumerate() {
+            str_to_char_pos[pos] = i;
+        }
+        boundaries.resize(char_types.len() - 1, CharacterBoundary::Unknown);
+        Ok(())
+    }
+
+    /// Creates a new [`Sentence`] from a given text without any annotation.
+    ///
+    /// # Errors
+    ///
+    /// If the given `text` is empty, an error variant will be returned.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use vaporetto::Sentence;
+    ///
+    /// let s = Sentence::from_raw("まぁ良いだろう").unwrap();
+    /// let mut buf = String::new();
+    /// s.write_partial_annotation_text(&mut buf);
+    /// assert_eq!("ま ぁ 良 い だ ろ う", buf);
+    ///
+    /// let s = Sentence::from_raw("");
+    /// assert!(s.is_err());
+    /// ```
+    pub fn from_raw(text: impl Into<Cow<'a, str>>) -> Result<Self> {
+        let text = text.into();
+        let mut char_types = vec![];
+        let mut boundaries = vec![];
+        let mut str_to_char_pos = vec![];
+        let mut char_to_str_pos = vec![];
+        Self::parse_raw(
+            &text,
+            &mut char_types,
+            &mut boundaries,
+            &mut str_to_char_pos,
+            &mut char_to_str_pos,
+        )?;
+        Ok(Self {
+            text,
+            char_types,
+            boundaries,
+            boundary_scores: vec![],
+            score_padding: 0,
+            char_pma_states: vec![],
+            type_pma_states: vec![],
+            predictor: None,
+            tags: vec![],
+            n_tags: 0,
+            str_to_char_pos,
+            char_to_str_pos,
+        })
+    }
+
+    /// Updates the [`Sentence`] using a given text without any annotation.
+    ///
+    /// # Errors
+    ///
+    /// If the given `text` is empty, an error variant will be returned.
+    /// If an error occurs, the sentence is replaced with a single space.
+ /// + /// # Examples + /// + /// ``` + /// use vaporetto::Sentence; + /// + /// let mut s = Sentence::from_raw("まぁ良いだろう").unwrap(); + /// s.update_raw("まぁ社長は火星猫だ").unwrap(); + /// assert_eq!("まぁ社長は火星猫だ", s.as_raw_text()); + /// ``` + pub fn update_raw(&mut self, text: impl Into>) -> Result<()> { + self.text = text.into(); + if let Err(e) = Self::parse_raw( + &self.text, + &mut self.char_types, + &mut self.boundaries, + &mut self.str_to_char_pos, + &mut self.char_to_str_pos, + ) { + self.set_default(); + return Err(e); + } + self.boundary_scores.clear(); + self.score_padding = 0; + self.char_pma_states.clear(); + self.type_pma_states.clear(); + self.predictor.take(); + self.tags.clear(); Ok(()) } - fn parse_tokenized_text( + fn parse_tokenized( tokenized_text: &str, text: &mut String, - chars: &mut Vec, - boundaries: &mut Vec, - tags: &mut Vec>>, + char_types: &mut Vec, + boundaries: &mut Vec, + str_to_char_pos: &mut Vec, + char_to_str_pos: &mut Vec, + tags: &mut Vec>>, ) -> Result<()> { if tokenized_text.is_empty() { return Err(VaporettoError::invalid_argument( @@ -247,17 +290,17 @@ impl Sentence { "must contain at least one character", )); } - text.clear(); - text.reserve(tokenized_text.len()); - chars.clear(); + char_types.clear(); boundaries.clear(); - tags.clear(); - - let mut tag_str_tmp = None; + str_to_char_pos.clear(); + char_to_str_pos.clear(); + char_to_str_pos.push(0); let mut tag_str = None; let mut prev_boundary = false; let mut escape = false; + let mut tags_tmp: Vec> = vec![]; + let mut pos = 0; for c in tokenized_text.chars() { match (escape, c) { // escape a following character @@ -266,7 +309,7 @@ impl Sentence { } // token boundary (false, ' ') => { - if chars.is_empty() { + if text.is_empty() { return Err(VaporettoError::invalid_argument( "tokenized_text", "must not start with a whitespace", @@ -278,456 +321,501 @@ impl Sentence { "must not contain consecutive whitespaces", )); } + if let Some(tag) = tag_str.take() { + tags_tmp.last_mut().unwrap().push(tag); + } prev_boundary = true; - tag_str = tag_str_tmp.take(); } - // POS tag + // tag (false, '/') => { - if chars.is_empty() || prev_boundary { + if text.is_empty() || prev_boundary { return Err(VaporettoError::invalid_argument( "tokenized_text", "a slash must follow a character", )); } - if tag_str_tmp.is_some() { - return Err(VaporettoError::invalid_argument( - "tokenized_text", - "invalid slash found", - )); + if let Some(tag) = tag_str.replace(String::new()) { + tags_tmp.last_mut().unwrap().push(tag); } - tag_str_tmp.replace("".to_string()); } // escaped character or other character (_, _) => { - if let Some(tag) = tag_str_tmp.as_mut() { + escape = false; + if c == '\0' { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "must not contain NULL", + )); + } + if let Some(tag) = tag_str.as_mut() { tag.push(c); continue; } - if !chars.is_empty() { + if !text.is_empty() { boundaries.push(if prev_boundary { - BoundaryType::WordBoundary + CharacterBoundary::WordBoundary } else { - BoundaryType::NotWordBoundary + CharacterBoundary::NotWordBoundary }); - tags.push(tag_str.take().map(Arc::new)); - } - if c == '\0' { - return Err(VaporettoError::invalid_argument( - "tokenized_text", - "must not contain NULL", - )); } prev_boundary = false; - escape = false; text.push(c); - chars.push(c); + char_types.push(CharacterType::get_type(c) as u8); + pos += c.len_utf8(); + char_to_str_pos.push(pos); + tags_tmp.push(vec![]); } }; } - if prev_boundary { return Err(VaporettoError::invalid_argument( "tokenized_text", 
"must not end with a whitespace", )); } - tags.push(tag_str_tmp.take().map(Arc::new)); + str_to_char_pos.resize(pos + 1, 0); + for (i, &pos) in char_to_str_pos.iter().enumerate() { + str_to_char_pos[pos] = i; + } + if let Some(tag) = tag_str.take() { + tags_tmp.last_mut().unwrap().push(tag); + } + let n_tags = tags_tmp.iter().fold(0, |acc, x| acc.max(x.len())); + tags.clear(); + for ts in tags_tmp { + let n_fill_none = n_tags - ts.len(); + for t in ts { + if t.is_empty() { + tags.push(None); + } else { + tags.push(Some(Cow::Owned(t))); + } + } + for _ in 0..n_fill_none { + tags.push(None); + } + } + Ok(()) + } + + /// Creates a new [`Sentence`] from a tokenized text. + /// + /// A tokenized text must be annotated by the following rules: + /// - A whitespace (`' '`) is inserted to each token boundary. + /// - If necessary, multiple tags following each slash (`'/'`) can be added to each token. + /// - Each character following a back slash (`'\\'`) is escaped. + /// + /// # Errors + /// + /// This function will return an error variant when the given text is empty, starts/ends with a + /// whitespace, or contains consecutive whitespaces. + /// + /// # Examples + /// + /// ``` + /// use vaporetto::Sentence; + /// + /// let s = Sentence::from_tokenized("まぁ 社長 は 火星 猫 だ"); + /// assert_eq!("まぁ社長は火星猫だ", s.unwrap().as_raw_text()); + /// + /// let s = Sentence::from_tokenized("まぁ/名詞 社長/名詞 は/助詞 火星/名詞 猫/名詞 だ/助動詞"); + /// assert_eq!("まぁ社長は火星猫だ", s.unwrap().as_raw_text()); + /// + /// let s = Sentence::from_tokenized("まぁ 社長 は 火星 猫 だ"); + /// assert!(s.is_err()); + /// ``` + pub fn from_tokenized(tokenized_text: &str) -> Result { + let mut text = String::new(); + let mut char_types = vec![]; + let mut boundaries = vec![]; + let mut str_to_char_pos = vec![]; + let mut char_to_str_pos = vec![]; + let mut tags = vec![]; + Self::parse_tokenized( + tokenized_text, + &mut text, + &mut char_types, + &mut boundaries, + &mut str_to_char_pos, + &mut char_to_str_pos, + &mut tags, + )?; + let n_tags = tags.len() / char_types.len(); + Ok(Self { + text: Cow::Owned(text), + char_types, + boundaries, + boundary_scores: vec![], + score_padding: 0, + char_pma_states: vec![], + type_pma_states: vec![], + predictor: None, + tags, + n_tags, + str_to_char_pos, + char_to_str_pos, + }) + } + /// Updates the [`Sentence`] using a tokenized text. + /// + /// A tokenized text must be annotated by the following rules: + /// - A whitespace (`' '`) is inserted to each token boundary. + /// - If necessary, multiple tags following each slash (`'/'`) can be added to each token. + /// - Each character following a back slash (`'\\'`) is escaped. + /// + /// # Errors + /// + /// This function will return an error variant when the given text is empty, starts/ends with a + /// whitespace, or contains consecutive whitespaces. 
+ /// + /// # Examples + /// + /// ``` + /// use vaporetto::Sentence; + /// + /// let mut s = Sentence::default(); + /// + /// s.update_tokenized("まぁ 良い だろう").unwrap(); + /// assert_eq!("まぁ良いだろう", s.as_raw_text()); + /// + /// s.update_tokenized("まぁ/副詞/マー 良い/形容詞/ヨイ だろう/助動詞/ダロー").unwrap(); + /// assert_eq!("まぁ良いだろう", s.as_raw_text()); + /// ``` + pub fn update_tokenized(&mut self, tokenized_text: &str) -> Result<()> { + if let Err(e) = Self::parse_tokenized( + tokenized_text, + self.text.to_mut(), + &mut self.char_types, + &mut self.boundaries, + &mut self.str_to_char_pos, + &mut self.char_to_str_pos, + &mut self.tags, + ) { + self.set_default(); + return Err(e); + } + self.boundary_scores.clear(); + self.score_padding = 0; + self.char_pma_states.clear(); + self.type_pma_states.clear(); + self.predictor.take(); + self.n_tags = self.tags.len() / self.char_types.len(); Ok(()) } fn parse_partial_annotation( - labeled_text: &str, + partial_annotation_text: &str, text: &mut String, - chars: &mut Vec, - boundaries: &mut Vec, - tags: &mut Vec>>, + char_types: &mut Vec, + boundaries: &mut Vec, + str_to_char_pos: &mut Vec, + char_to_str_pos: &mut Vec, + tags: &mut Vec>>, ) -> Result<()> { - if labeled_text.is_empty() { + if partial_annotation_text.is_empty() { return Err(VaporettoError::invalid_argument( - "labeled_text", + "partial_annotation_text", "must contain at least one character", )); } - text.clear(); - chars.clear(); + char_types.clear(); boundaries.clear(); - tags.clear(); - + str_to_char_pos.clear(); + char_to_str_pos.clear(); + char_to_str_pos.push(0); let mut tag_str = None; + let mut escape = false; + let mut tags_tmp: Vec> = vec![]; + let mut pos = 0; let mut is_char = true; - let mut fixed_token = true; - for c in labeled_text.chars() { + for c in partial_annotation_text.chars() { if is_char { if c == '\0' { return Err(VaporettoError::invalid_argument( - "labeled_text", + "partial_annotation_text", "must not contain NULL", )); } text.push(c); - chars.push(c); + char_types.push(CharacterType::get_type(c) as u8); + pos += c.len_utf8(); + char_to_str_pos.push(pos); + tags_tmp.push(vec![]); is_char = false; continue; } - match c { - // unannotated boundary - ' ' => { - if tag_str.is_some() { - return Err(VaporettoError::invalid_argument( - "labeled_text", - "POS tag must be annotated to a token", - )); + match (escape, c) { + (false, '\\') => { + escape = true; + } + (false, ' ') => { + if let Some(tag) = tag_str.take() { + tags_tmp.last_mut().unwrap().push(tag); } - tags.push(None); - boundaries.push(BoundaryType::Unknown); + boundaries.push(CharacterBoundary::Unknown); is_char = true; - fixed_token = false; } - // token boundary - '|' => { - if !fixed_token && tag_str.is_some() { - return Err(VaporettoError::invalid_argument( - "labeled_text", - "POS tag must be annotated to a token", - )); + (false, '-') => { + if let Some(tag) = tag_str.take() { + tags_tmp.last_mut().unwrap().push(tag); } - tags.push(tag_str.take().map(Arc::new)); - boundaries.push(BoundaryType::WordBoundary); + boundaries.push(CharacterBoundary::NotWordBoundary); is_char = true; - fixed_token = true; } - // not token boundary - '-' => { - if tag_str.is_some() { - return Err(VaporettoError::invalid_argument( - "labeled_text", - "POS tag must be annotated to a token", - )); + (false, '|') => { + if let Some(tag) = tag_str.take() { + tags_tmp.last_mut().unwrap().push(tag); } - tags.push(None); - boundaries.push(BoundaryType::NotWordBoundary); + boundaries.push(CharacterBoundary::WordBoundary); is_char = true; } - // 
POS tag
-                '/' => {
-                    if tag_str.is_some() {
-                        return Err(VaporettoError::invalid_argument(
-                            "labeled_text",
-                            "invalid slash found",
-                        ));
+                (false, '/') => {
+                    let tag = tag_str.replace(String::new());
+                    if let Some(tag) = tag {
+                        tags_tmp.last_mut().unwrap().push(tag);
                     }
-                    tag_str.replace("".to_string());
                 }
                 _ => {
+                    escape = false;
                     if let Some(tag) = tag_str.as_mut() {
                         tag.push(c);
                     } else {
                         return Err(VaporettoError::invalid_argument(
-                            "labeled_text",
+                            "partial_annotation_text",
                             format!("contains an invalid boundary character: '{}'", c),
                         ));
                     }
                 }
             }
         }
-        tags.push(tag_str.take().map(Arc::new));
-        if chars.len() != boundaries.len() + 1 {
+        if is_char {
             return Err(VaporettoError::invalid_argument(
-                "labeled_text",
+                "partial_annotation_text",
                 "invalid annotation",
             ));
         }
-
-        Ok(())
-    }
-
-    /// Updates char_to_str_pos, str_to_char_pos, and char_type.
-    ///
-    /// This function allocates:
-    ///
-    /// * char_to_str_pos: chars.len() + 1
-    /// * str_to_char_pos: text.len() + 1
-    /// * char_type: chars.len()
-    ///
-    /// If these variables already have sufficient spaces, this function reuses them.
-    fn update_common_info(&mut self) {
-        self.char_to_str_pos.clear();
-        self.str_to_char_pos.clear();
-        self.char_type.clear();
-        self.boundary_scores.clear();
-        self.tag_scores.clear();
-
-        let mut pos = 0;
-        self.char_to_str_pos.push(0);
-        for &c in &self.chars {
-            pos += c.len_utf8();
-            self.char_to_str_pos.push(pos);
-            self.char_type.push(CharacterType::get_type(c) as u8)
+        str_to_char_pos.resize(pos + 1, 0);
+        for (i, &pos) in char_to_str_pos.iter().enumerate() {
+            str_to_char_pos[pos] = i;
         }
-
-        debug_assert!(pos == self.text.len());
-
-        self.str_to_char_pos.resize(self.text.len() + 1, 0);
-        for (i, &j) in self.char_to_str_pos.iter().enumerate() {
-            // j is always lower than pos + 1, so the following is safe.
-            unsafe {
-                *self.str_to_char_pos.get_unchecked_mut(j) = i;
+        if let Some(tag) = tag_str.take() {
+            tags_tmp.last_mut().unwrap().push(tag);
+        }
+        let n_tags = tags_tmp.iter().fold(0, |acc, x| acc.max(x.len()));
+        tags.clear();
+        for ts in tags_tmp {
+            let n_fill_none = n_tags - ts.len();
+            for t in ts {
+                if t.is_empty() {
+                    tags.push(None);
+                } else {
+                    tags.push(Some(Cow::Owned(t)));
+                }
             }
+            for _ in 0..n_fill_none {
+                tags.push(None);
+            }
         }
+        Ok(())
     }

-    /// Creates a new [`Sentence`] from a given string.
-    ///
-    /// # Arguments
+    /// Creates a new [`Sentence`] from a text with partial annotations.
     ///
-    /// * `raw_text` - A raw string without any annotation.
-    ///
-    /// # Returns
+    /// Each character boundary must be annotated by the following rules:
+    /// - If the boundary is a token boundary, a pipe symbol (`'|'`) is inserted.
+    /// - If the boundary is not a token boundary, a dash symbol (`'-'`) is inserted.
+    /// - If the boundary is not annotated, a whitespace (`' '`) is inserted.
     ///
-    /// A new [`Sentence`].
+    /// In addition, multiple tags following each slash (`'/'`) can be added to each token.
+    /// Tags can also be inserted at non-word boundaries, but such tags are ignored.
     ///
     /// # Errors
     ///
-    /// If the given `raw_text` is empty, an error variant will be returned.
+    /// This function will return an error variant when the text is empty, the length of the
+    /// text is an even number (a valid annotation alternates characters and boundary marks,
+    /// so its length is always odd), or the text contains invalid boundary characters.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use vaporetto::Sentence;
+    ///
+    /// let mut buf = String::new();
+    ///
+    /// let s = Sentence::from_partial_annotation(
+    ///     "ま-ぁ|良-い|だ-ろ-う"
+    /// ).unwrap();
+    /// s.write_tokenized_text(&mut buf);
+    /// assert_eq!("まぁ 良い だろう", buf);
+    ///
+    /// let s = Sentence::from_partial_annotation(
+    ///     "ま-ぁ/名詞/マー|社-長/名詞/シャチョー|は/助詞/ワ|火-星 猫|だ/助動詞/ダ"
+    /// ).unwrap();
+    /// s.write_tokenized_text(&mut buf);
+    /// assert_eq!("まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ だ/助動詞/ダ", buf);
+    ///
+    /// let s = Sentence::from_partial_annotation(
+    ///     "ま-ぁ/名詞/マー|社-長/名詞/シャチョー|は/助詞/ワ|火/名詞/ヒ-星|猫|だ/助動詞/ダ"
+    /// ).unwrap();
+    /// s.write_tokenized_text(&mut buf);
+    /// assert_eq!("まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ 火星 猫 だ/助動詞/ダ", buf);
+    /// ```
+    pub fn from_partial_annotation(partial_annotation_text: &str) -> Result<Self> {
+        let mut text = String::new();
+        let mut char_types = vec![];
+        let mut boundaries = vec![];
+        let mut str_to_char_pos = vec![];
+        let mut char_to_str_pos = vec![];
+        let mut tags = vec![];
+        Self::parse_partial_annotation(
+            partial_annotation_text,
+            &mut text,
+            &mut char_types,
+            &mut boundaries,
+            &mut str_to_char_pos,
+            &mut char_to_str_pos,
+            &mut tags,
+        )?;
+        let n_tags = tags.len() / char_types.len();
+        Ok(Self {
+            text: Cow::Owned(text),
+            char_types,
+            boundaries,
+            boundary_scores: vec![],
+            score_padding: 0,
+            char_pma_states: vec![],
+            type_pma_states: vec![],
+            predictor: None,
+            tags,
+            n_tags,
+            str_to_char_pos,
+            char_to_str_pos,
+        })
+    }

+    /// Updates the [`Sentence`] using a text with partial annotations.
+    ///
+    /// Each character boundary must be annotated by the following rules:
+    /// - If the boundary is a token boundary, a pipe symbol (`'|'`) is inserted.
+    /// - If the boundary is not a token boundary, a dash symbol (`'-'`) is inserted.
+    /// - If the boundary is not annotated, a whitespace (`' '`) is inserted.
+    ///
+    /// In addition, multiple tags following each slash (`'/'`) can be added to each token.
+    /// Tags can also be inserted at non-word boundaries, but such tags are ignored.
+    ///
+    /// # Errors
+    ///
+    /// This function will return an error variant when the text is empty, the length of the
+    /// text is an even number (a valid annotation alternates characters and boundary marks,
+    /// so its length is always odd), or the text contains invalid boundary characters.
/// /// # Examples /// /// ``` /// use vaporetto::Sentence; /// - /// let mut s = Sentence::from_raw("How are you?").unwrap(); - /// s.update_raw("I am file.").unwrap(); - /// assert_eq!("I am file.", s.to_raw_string()); + /// let mut buf = String::new(); + /// let mut s = Sentence::default(); + /// + /// s.update_partial_annotation( + /// "ま-ぁ|良-い|だ-ろ-う" + /// ).unwrap(); + /// s.write_tokenized_text(&mut buf); + /// assert_eq!("まぁ 良い だろう", buf); + /// + /// s.update_partial_annotation( + /// "ま-ぁ/名詞/マー|社-長/名詞/シャチョー|は/助詞/ワ|火-星 猫|だ/助動詞/ダ" + /// ).unwrap(); + /// s.write_tokenized_text(&mut buf); + /// assert_eq!("まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ だ/助動詞/ダ", buf); + /// + /// s.update_partial_annotation( + /// "ま-ぁ/名詞/マー|社-長/名詞/シャチョー|は/助詞/ワ|火/名詞/ヒ-星|猫|だ/助動詞/ダ" + /// ).unwrap(); + /// s.write_tokenized_text(&mut buf); + /// assert_eq!("まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ 火星 猫 だ/助動詞/ダ", buf); /// ``` - pub fn update_raw(&mut self, raw_text: S) -> Result<()> - where - S: Into, - { - let raw_text = raw_text.into(); - - match Self::parse_raw_text( - &raw_text, - &mut self.chars, + pub fn update_partial_annotation(&mut self, partial_annotation_text: &str) -> Result<()> { + if let Err(e) = Self::parse_partial_annotation( + partial_annotation_text, + self.text.to_mut(), + &mut self.char_types, &mut self.boundaries, + &mut self.str_to_char_pos, + &mut self.char_to_str_pos, &mut self.tags, ) { - Ok(_) => { - self.text = raw_text; - self.update_common_info(); - Ok(()) - } - Err(e) => { - self.clear(); - Err(e) - } + self.set_default(); + return Err(e); } + self.boundary_scores.clear(); + self.score_padding = 0; + self.char_pma_states.clear(); + self.type_pma_states.clear(); + self.predictor.take(); + self.n_tags = self.tags.len() / self.char_types.len(); + Ok(()) } - /// Gets a string without any annotation. - /// - /// # Returns - /// - /// A reference to the string. + /// Gets a text without any annotation. /// /// # Examples /// /// ``` /// use vaporetto::Sentence; /// - /// let s = Sentence::from_raw("How are you?").unwrap(); - /// assert_eq!("How are you?", s.to_raw_string()); + /// let s = Sentence::from_tokenized("まぁ/副詞 良い/形容詞 だろう/助動詞").unwrap(); + /// assert_eq!("まぁ良いだろう", s.as_raw_text()); /// ``` - pub fn to_raw_string(&self) -> &str { + #[inline] + pub fn as_raw_text(&self) -> &str { &self.text } - /// Creates a new [`Sentence`] from a tokenized string. - /// - /// # Arguments - /// - /// * `tokenized_text` - A tokenized text that is annotated by the following rules: - /// - A whitespace (`' '`) is inserted to each token boundary. - /// - If necessary, a POS tag following a slash (`'/'`) can be added to each token. - /// - Each character following a back slash (`'\\'`) is escaped. - /// - /// # Returns - /// - /// A new [`Sentence`]. - /// - /// # Errors - /// - /// This function will return an error variant when: - /// - /// * `tokenized_text` is empty. - /// * `tokenized_text` starts/ends with a whitespace. - /// * `tokenized_text` contains consecutive whitespaces. + /// Returns an iterator of tokens. Tokens adjacent to [`CharacterBoundary::Unknown`] will be + /// skipped. 
/// /// # Examples /// /// ``` /// use vaporetto::Sentence; /// - /// let s = Sentence::from_tokenized("How are you?"); - /// assert!(s.is_ok()); - /// - /// let s = Sentence::from_tokenized("How/WRB are/VBP you?"); - /// assert!(s.is_ok()); - /// - /// let s = Sentence::from_tokenized("How are you?"); - /// assert!(s.is_err()); - /// ``` - pub fn from_tokenized(tokenized_text: S) -> Result - where - S: AsRef, - { - let tokenized_text = tokenized_text.as_ref(); - - let mut text = String::with_capacity(0); - let mut chars = Vec::with_capacity(0); - let mut boundaries = Vec::with_capacity(0); - let mut tags = Vec::with_capacity(0); - - Self::parse_tokenized_text( - tokenized_text, - &mut text, - &mut chars, - &mut boundaries, - &mut tags, - )?; - - Ok(Self::internal_new(text, chars, boundaries, tags)) - } - - /// Updates the [`Sentence`] using tokenized string. - /// - /// # Arguments - /// - /// * `tokenized_text` - A tokenized text that is annotated by the following rules: - /// - A whitespace (`' '`) is inserted to each token boundary. - /// - If necessary, a POS tag following a slash (`'/'`) can be added to each token. - /// - Each character following a back slash (`'\\'`) is escaped. - /// - /// # Errors - /// - /// This function will return an error variant when: - /// - /// * `tokenized_text` is empty. - /// * `tokenized_text` starts/ends with a whitespace. - /// * `tokenized_text` contains consecutive whitespaces. + /// let s = Sentence::from_partial_annotation("ま-ぁ|社-長|は|火-星 猫|だ").unwrap(); + /// let mut it = s.iter_tokens(); /// - /// When an error is occurred, the sentence will be replaced with a white space. - /// - /// # Examples + /// let token = it.next().unwrap(); + /// assert_eq!("まぁ", token.surface()); + /// assert_eq!(0, token.start()); + /// assert_eq!(2, token.end()); /// - /// ``` - /// use vaporetto::Sentence; + /// let token = it.next().unwrap(); + /// assert_eq!("社長", token.surface()); + /// assert_eq!(2, token.start()); + /// assert_eq!(4, token.end()); /// - /// let mut s = Sentence::from_tokenized("How are you?").unwrap(); + /// let token = it.next().unwrap(); + /// assert_eq!("は", token.surface()); + /// assert_eq!(4, token.start()); + /// assert_eq!(5, token.end()); /// - /// s.update_tokenized("I am fine").unwrap(); - /// assert_eq!("Iamfine", s.to_raw_string()); + /// let token = it.next().unwrap(); + /// assert_eq!("だ", token.surface()); + /// assert_eq!(8, token.start()); + /// assert_eq!(9, token.end()); /// - /// s.update_tokenized("How/WRB are/VBP you ?/.").unwrap(); - /// assert_eq!("Howareyou?", s.to_raw_string()); + /// assert!(it.next().is_none()); /// ``` - pub fn update_tokenized(&mut self, tokenized_text: S) -> Result<()> - where - S: AsRef, - { - let tokenized_text = tokenized_text.as_ref(); - - match Self::parse_tokenized_text( - tokenized_text, - &mut self.text, - &mut self.chars, - &mut self.boundaries, - &mut self.tags, - ) { - Ok(_) => { - self.update_common_info(); - Ok(()) - } - Err(e) => { - self.clear(); - Err(e) - } + pub const fn iter_tokens(&'a self) -> TokenIterator<'a, 'b> { + TokenIterator { + token: Token { + sentence: self, + start: 0, + end: 0, + }, } } - /// Generates a string with whitespaces for word boundaries. - /// - /// # Returns - /// - /// A newly allocated string containing whitespaces for word boundaries. - /// - /// # Errors - /// - /// If the sentence contains unknown boundary, an error variant will be returned. 
-    ///
-    /// # Examples
-    ///
-    /// ```
-    /// use vaporetto::Sentence;
-    ///
-    /// let s = Sentence::from_tokenized("How are you?").unwrap();
-    /// assert_eq!("How are you?", s.to_tokenized_string().unwrap());
-    ///
-    /// let s = Sentence::from_tokenized("How/WRB are/VBP you?").unwrap();
-    /// assert_eq!("How/WRB are/VBP you?", s.to_tokenized_string().unwrap());
-    /// ```
-    pub fn to_tokenized_string(&self) -> Result<String> {
-        let mut result = String::with_capacity(self.text.len() * 2 - 1);
-        self.write_tokenized_string(&mut result)?;
-        Ok(result)
-    }
-
-    /// Writes a string with whitespaces for word boundaries.
-    ///
-    /// # Arguments
-    ///
-    /// * `buffer` - A string buffer.
-    ///
-    /// # Errors
-    ///
-    /// If the sentence contains unknown boundary, an error variant will be returned.
+    /// Writes a tokenized text. Tokens adjacent to [`CharacterBoundary::Unknown`] will be skipped.
     ///
     /// # Examples
     ///
@@ -736,506 +824,450 @@ impl Sentence {
     ///
     /// let mut buf = String::new();
     ///
-    /// let s = Sentence::from_tokenized("How are you?").unwrap();
-    /// s.write_tokenized_string(&mut buf).unwrap();
-    /// assert_eq!("How are you?", buf);
-    ///
-    /// let s = Sentence::from_tokenized("How/WRB are/VBP you?").unwrap();
-    /// s.write_tokenized_string(&mut buf).unwrap();
-    /// assert_eq!("How/WRB are/VBP you?", buf);
+    /// let s = Sentence::from_partial_annotation(
+    ///     "ま-ぁ/名詞|社-長/名詞|は/助詞|火-星/名詞|猫/名詞|だ/助動詞"
+    /// ).unwrap();
+    /// s.write_tokenized_text(&mut buf);
+    /// assert_eq!("まぁ/名詞 社長/名詞 は/助詞 火星/名詞 猫/名詞 だ/助動詞", buf);
+    ///
+    /// let s = Sentence::from_partial_annotation(
+    ///     "ま-ぁ/名詞|社-長/名詞|は/助詞|火-星 猫|だ/助動詞"
+    /// ).unwrap();
+    /// s.write_tokenized_text(&mut buf);
+    /// assert_eq!("まぁ/名詞 社長/名詞 は/助詞 だ/助動詞", buf);
     /// ```
-    pub fn write_tokenized_string(&self, buffer: &mut String) -> Result<()> {
-        let mut chars_iter = self.text.chars();
-        buffer.clear();
-        let c = chars_iter.next().unwrap();
-        match c {
-            '\\' | '/' | '&' | ' ' => buffer.push('\\'),
-            _ => (),
-        }
-        buffer.push(c);
-        for (i, (c, b)) in chars_iter.zip(&self.boundaries).enumerate() {
-            match b {
-                BoundaryType::WordBoundary => {
-                    if let Some(tag) = self.tags.get(i).and_then(|x| x.as_ref()) {
-                        buffer.push('/');
-                        buffer.push_str(tag);
-                    }
-                    buffer.push(' ');
-                }
-                BoundaryType::NotWordBoundary => (),
-                BoundaryType::Unknown => {
-                    return Err(VaporettoError::invalid_sentence(
-                        "contains an unknown boundary",
-                    ));
-                }
-            }
-            match c {
-                '\\' | '/' | '&' | ' ' => buffer.push('\\'),
-                _ => (),
-            }
-            buffer.push(c);
-        }
-        if let Some(tag) = self.tags.last().and_then(|x| x.as_ref()) {
-            buffer.push('/');
-            buffer.push_str(tag);
-        }
-        Ok(())
-    }
+    pub fn write_tokenized_text(&self, buf: &mut String) {
+        buf.clear();
+        // `buf` always consists of a valid UTF-8 sequence because
+        // `Token::surface` and `Token::tags` return valid `str` values.
+        unsafe {
+            let buf = buf.as_mut_vec();
+            for token in self.iter_tokens() {
+                if !buf.is_empty() {
+                    buf.push(b' ');
+                }
+                for &b in token.surface().as_bytes() {
+                    match b {
+                        b' ' | b'\\' | b'/' => {
+                            buf.push(b'\\');
+                        }
+                        _ => (),
+                    }
+                    buf.push(b);
+                }
+                let ts = token.tags();
+                for tag in &ts[..ts.iter().rposition(|x| x.is_some()).map_or(0, |x| x + 1)] {
+                    buf.push(b'/');
+                    if let Some(tag) = tag {
+                        for &b in tag.as_bytes() {
+                            match b {
+                                b' ' | b'\\' | b'/' => {
+                                    buf.push(b'\\');
+                                }
+                                _ => (),
+                            }
+                            buf.push(b);
+                        }
+                    }
+                }
+            }
+        }
+    }
 
-    /// Generates a vector of tokens.
-    ///
-    /// # Returns
+    /// Writes a text with partial annotations.
     ///
-    /// A newly allocated vector of tokens.
+    /// # Examples
     ///
-    /// # Errors
+    /// ```
+    /// use vaporetto::Sentence;
     ///
-    /// If the sentence contains unknown boundaries, an error variant will be returned.
+    /// let mut buf = String::new();
     ///
-    /// # Examples
+    /// let s = Sentence::from_tokenized("まぁ 良い だろう").unwrap();
+    /// s.write_partial_annotation_text(&mut buf);
+    /// assert_eq!("ま-ぁ|良-い|だ-ろ-う", buf);
     ///
+    /// let s = Sentence::from_tokenized(
+    ///     "まぁ/副詞/マー 良い/形容詞/ヨイ だろう/助動詞/ダロー"
+    /// ).unwrap();
+    /// s.write_partial_annotation_text(&mut buf);
+    /// assert_eq!("ま-ぁ/副詞/マー|良-い/形容詞/ヨイ|だ-ろ-う/助動詞/ダロー", buf);
     /// ```
-    /// use vaporetto::{Sentence, Token};
-    ///
-    /// let s = Sentence::from_tokenized("How are you ?").unwrap();
-    /// assert_eq!(vec![
-    ///     Token { surface: "How", tag: None },
-    ///     Token { surface: "are", tag: None },
-    ///     Token { surface: "you", tag: None },
-    ///     Token { surface: "?", tag: None },
-    /// ], s.to_tokenized_vec().unwrap());
-    ///
-    /// let s = Sentence::from_tokenized("How/WRB are/VBP you/PRP ?/.").unwrap();
-    /// assert_eq!(vec![
-    ///     Token { surface: "How", tag: Some("WRB") },
-    ///     Token { surface: "are", tag: Some("VBP") },
-    ///     Token { surface: "you", tag: Some("PRP") },
-    ///     Token { surface: "?", tag: Some(".") },
-    /// ], s.to_tokenized_vec().unwrap());
-    /// ```
-    pub fn to_tokenized_vec(&self) -> Result<Vec<Token>> {
-        let mut result = vec![];
-        let mut start = 0;
-        for (i, b) in self.boundaries.iter().enumerate() {
-            match b {
-                BoundaryType::WordBoundary => {
-                    let end = unsafe { *self.char_to_str_pos.get_unchecked(i + 1) };
-                    let surface = unsafe { self.text.get_unchecked(start..end) };
-                    let tag = self
-                        .tags
-                        .get(i)
-                        .and_then(|x| x.as_ref())
-                        .map(|x| x.as_str());
-                    result.push(Token { surface, tag });
-                    start = end;
-                }
-                BoundaryType::NotWordBoundary => (),
-                BoundaryType::Unknown => {
-                    return Err(VaporettoError::invalid_sentence(
-                        "contains an unknown boundary",
-                    ));
-                }
-            }
-        }
-        let surface = unsafe { self.text.get_unchecked(start..) };
-        let tag = self
-            .tags
-            .last()
-            .and_then(|x| x.as_ref())
-            .map(|x| x.as_str());
-        result.push(Token { surface, tag });
-        Ok(result)
+    pub fn write_partial_annotation_text(&self, buf: &mut String) {
+        buf.clear();
+        let mut char_iter = self.text.chars();
+        buf.push(char_iter.next().unwrap());
+        if self.n_tags != 0 {
+            let mut tag_iter = self.tags.chunks_exact(self.n_tags);
+            let ts = tag_iter.next().unwrap();
+            for tag in &ts[..ts.iter().rposition(|x| x.is_some()).map_or(0, |x| x + 1)] {
+                buf.push('/');
+                if let Some(tag) = tag {
+                    buf.push_str(tag);
+                }
+            }
+            for ((c, ts), &b) in char_iter.zip(tag_iter).zip(&self.boundaries) {
+                buf.push(match b {
+                    CharacterBoundary::NotWordBoundary => '-',
+                    CharacterBoundary::WordBoundary => '|',
+                    CharacterBoundary::Unknown => ' ',
+                });
+                buf.push(c);
+                for tag in &ts[..ts.iter().rposition(|x| x.is_some()).map_or(0, |x| x + 1)] {
+                    buf.push('/');
+                    if let Some(tag) = tag {
+                        buf.push_str(tag);
+                    }
+                }
+            }
+        } else {
+            for (c, &b) in char_iter.zip(&self.boundaries) {
+                buf.push(match b {
+                    CharacterBoundary::NotWordBoundary => '-',
+                    CharacterBoundary::WordBoundary => '|',
+                    CharacterBoundary::Unknown => ' ',
+                });
+                buf.push(c);
+            }
+        }
     }
 
-    /// Creates a new [`Sentence`] from a string with partial annotations.
-    ///
-    /// # Arguments
-    ///
-    /// * `labeled_text` - A partially annotated text. Each character boundary is annotated by the following rules:
-    ///   - If the boundary is a token boundary, a pipe symbol (`'|'`) is inserted.
-    ///   - If the boundary is not a token boundary, a dash symobl (`'-'`) is inserted.
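A short editorial sketch of the two writers defined above, derived from the unit tests later in this patch: `write_tokenized_text()` escapes `' '`, `'/'`, and `'\'` inside a surface, and `write_partial_annotation_text()` preserves an unannotated (`' '`) boundary on a round trip:

```rust
use vaporetto::Sentence;

fn main() {
    let mut buf = String::new();

    // '/' inside a surface is escaped on output, so it cannot be read back
    // as the slash that introduces a tag.
    let s = Sentence::from_tokenized("品詞 に \\/ を 用い る").unwrap();
    s.write_tokenized_text(&mut buf);
    assert_eq!("品詞 に \\/ を 用い る", buf);

    // The unannotated boundary between "し" and "た" survives a round trip.
    let s = Sentence::from_partial_annotation("火-星|に|行-き|ま-し た").unwrap();
    s.write_partial_annotation_text(&mut buf);
    assert_eq!("火-星|に|行-き|ま-し た", buf);
}
```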
-    ///   - If the boundary is not annotated, a whitespace (`' '`) is inserted.
-    ///
-    /// In addition, a POS tag following a slash (`'/'`) can be inserted to each token.
-    ///
-    /// # Returns
-    ///
-    /// A new [`Sentence`].
-    ///
-    /// # Errors
-    ///
-    /// This function will return an error variant when:
-    ///
-    /// * `labeled_text` is empty.
-    /// * The length of `lsbeled_text` is even numbers.
-    /// * `labeled_text` contains invalid boundary characters.
+    /// Removes tag information and updates the number of tags.
     ///
     /// # Examples
     ///
     /// ```
     /// use vaporetto::Sentence;
     ///
-    /// let s = Sentence::from_partial_annotation("g-o-o-d|i-d e-a");
-    /// assert!(s.is_ok());
-    ///
-    /// let s = Sentence::from_partial_annotation("I-t/PRP|'-s/VBZ|o-k-a-y/JJ|./.");
-    /// assert!(s.is_ok());
-    ///
-    /// let s = Sentence::from_partial_annotation("b-a-d/i-d-e-a");
-    /// assert!(s.is_err());
+    /// let mut s = Sentence::from_tokenized("火星/名詞/カセー に 行き/動詞 まし/助動詞/マシ た").unwrap();
+    /// let mut buf = String::new();
+    /// assert_eq!(2, s.n_tags());
+    /// assert_eq!(16, s.tags().len());
+    /// s.write_tokenized_text(&mut buf);
+    /// assert_eq!("火星/名詞/カセー に 行き/動詞 まし/助動詞/マシ た", buf);
+    ///
+    /// s.reset_tags(1);
+    /// assert_eq!(1, s.n_tags());
+    /// assert_eq!(8, s.tags().len());
+    /// s.write_tokenized_text(&mut buf);
+    /// assert_eq!("火星 に 行き まし た", buf);
     /// ```
-    pub fn from_partial_annotation<S>(labeled_text: S) -> Result<Self>
-    where
-        S: AsRef<str>,
-    {
-        let labeled_text = labeled_text.as_ref();
-
-        let mut text = String::with_capacity(0);
-        let mut chars = Vec::with_capacity(0);
-        let mut boundaries = Vec::with_capacity(0);
-        let mut tags = Vec::with_capacity(0);
-        Self::parse_partial_annotation(
-            labeled_text,
-            &mut text,
-            &mut chars,
-            &mut boundaries,
-            &mut tags,
-        )?;
-
-        Ok(Self::internal_new(text, chars, boundaries, tags))
+    #[inline]
+    pub fn reset_tags(&mut self, n_tags: usize) {
+        self.tags.clear();
+        self.tags.resize(n_tags * self.len(), None);
+        self.n_tags = n_tags;
     }
 
-    /// Updates the [`Sentence`] using a string with partial annotations.
-    ///
-    /// # Arguments
-    ///
-    /// * `labeled_text` - A partially annotated text. Each character boundary is annotated by the following rules:
-    ///   - If the boundary is a token boundary, a pipe symbol (`'|'`) is inserted.
-    ///   - If the boundary is not a token boundary, a dash symobl (`'-'`) is inserted.
-    ///   - If the boundary is not annotated, a whitespace (`' '`) is inserted.
-    ///
-    /// In addition, a POS tag following a slash (`'/'`) can be inserted to each token.
-    ///
-    /// # Errors
-    ///
-    /// This function will return an error variant when:
-    ///
-    /// * `labeled_text` is empty.
-    /// * The length of `lsbeled_text` is even numbers.
-    /// * `labeled_text` contains invalid boundary characters.
-    ///
-    /// When an error is occurred, the sentence will be replaced with a white space.
+    /// Returns a slice of character types.
     ///
     /// # Examples
     ///
     /// ```
-    /// use vaporetto::Sentence;
-    ///
-    /// let mut s = Sentence::from_raw("g-o-o-d|i-d e-a").unwrap();
-    /// s.update_partial_annotation("h-e-l-l-o").unwrap();
-    /// assert_eq!("hello", s.to_raw_string());
+    /// use vaporetto::{CharacterType, Sentence};
     ///
-    /// s.update_partial_annotation("I-t/PRP|'-s/VBZ|o-k-a-y/JJ|./.").unwrap();
-    /// assert_eq!("It'sokay.", s.to_raw_string());
+    /// let s = Sentence::from_tokenized("火星 に 行き まし た").unwrap();
+    /// assert_eq!(&[
+    ///     CharacterType::Kanji as u8,
+    ///     CharacterType::Kanji as u8,
+    ///     CharacterType::Hiragana as u8,
+    ///     CharacterType::Kanji as u8,
+    ///     CharacterType::Hiragana as u8,
+    ///     CharacterType::Hiragana as u8,
+    ///     CharacterType::Hiragana as u8,
+    ///     CharacterType::Hiragana as u8,
+    /// ], s.char_types());
     /// ```
-    pub fn update_partial_annotation<S>(&mut self, labeled_text: S) -> Result<()>
-    where
-        S: AsRef<str>,
-    {
-        let labeled_text = labeled_text.as_ref();
-
-        match Self::parse_partial_annotation(
-            labeled_text,
-            &mut self.text,
-            &mut self.chars,
-            &mut self.boundaries,
-            &mut self.tags,
-        ) {
-            Ok(_) => {
-                self.update_common_info();
-                Ok(())
-            }
-            Err(e) => {
-                self.clear();
-                Err(e)
-            }
-        }
+    #[inline]
+    pub fn char_types(&self) -> &[u8] {
+        &self.char_types
     }
 
-    /// Generates a string with partial annotations.
-    ///
-    /// # Returns
-    ///
-    /// A newly allocated string with partial annotations.
+    /// Returns a slice of boundary types.
     ///
     /// # Examples
     ///
     /// ```
-    /// use vaporetto::Sentence;
+    /// use vaporetto::{CharacterBoundary, Sentence};
     ///
-    /// let s = Sentence::from_tokenized("How are you ?").unwrap();
-    /// assert_eq!("H-o-w|a-r-e|y-o-u|?", &s.to_partial_annotation_string());
-    ///
-    /// let s = Sentence::from_tokenized("How/WRB are you/PRP ?").unwrap();
-    /// assert_eq!("H-o-w/WRB|a-r-e|y-o-u/PRP|?", &s.to_partial_annotation_string());
+    /// let s = Sentence::from_partial_annotation("火-星|に|行-き|ま-し た").unwrap();
+    /// assert_eq!(&[
+    ///     CharacterBoundary::NotWordBoundary,
+    ///     CharacterBoundary::WordBoundary,
+    ///     CharacterBoundary::WordBoundary,
+    ///     CharacterBoundary::NotWordBoundary,
+    ///     CharacterBoundary::WordBoundary,
+    ///     CharacterBoundary::NotWordBoundary,
+    ///     CharacterBoundary::Unknown,
+    /// ], s.boundaries());
     /// ```
-    pub fn to_partial_annotation_string(&self) -> String {
-        let mut result = String::with_capacity(self.text.len() * 2 - 1);
-        self.write_partial_annotation_string(&mut result);
-        result
+    #[inline]
+    pub fn boundaries(&self) -> &[CharacterBoundary] {
+        &self.boundaries
     }
 
-    /// Write a string with partial annotations.
-    ///
-    /// # Arguments
-    ///
-    /// * `buffer` - A string buffer.
-    ///
-    /// A newly allocated string with partial annotations.
+    /// Returns a mutable slice of boundary types.
     ///
     /// # Examples
     ///
     /// ```
-    /// use vaporetto::Sentence;
+    /// use vaporetto::{CharacterBoundary, Sentence};
     ///
+    /// let mut s = Sentence::from_partial_annotation("火-星|に|行-き|ま-し た").unwrap();
+    /// s.boundaries_mut()[6] = CharacterBoundary::WordBoundary;
     /// let mut buf = String::new();
-    ///
-    /// let s = Sentence::from_tokenized("How are you ?").unwrap();
-    /// s.write_partial_annotation_string(&mut buf);
-    /// assert_eq!("H-o-w|a-r-e|y-o-u|?", buf);
-    ///
-    /// let s = Sentence::from_tokenized("How/WRB are you/PRP ?").unwrap();
-    /// s.write_partial_annotation_string(&mut buf);
-    /// assert_eq!("H-o-w/WRB|a-r-e|y-o-u/PRP|?", buf);
+    /// s.write_partial_annotation_text(&mut buf);
+    /// assert_eq!("火-星|に|行-き|ま-し|た", buf);
     /// ```
-    pub fn write_partial_annotation_string(&self, buffer: &mut String) {
-        let mut chars_iter = self.text.chars();
-        buffer.clear();
-        buffer.push(chars_iter.next().unwrap());
-        for (i, (c, b)) in chars_iter.zip(&self.boundaries).enumerate() {
-            match b {
-                BoundaryType::WordBoundary => {
-                    if let Some(tag) = self.tags.get(i).and_then(|x| x.as_ref()) {
-                        buffer.push('/');
-                        buffer.push_str(tag);
-                    }
-                    buffer.push('|');
-                }
-                BoundaryType::NotWordBoundary => {
-                    buffer.push('-');
-                }
-                BoundaryType::Unknown => {
-                    buffer.push(' ');
-                }
-            }
-            buffer.push(c);
-        }
-        if let Some(tag) = self.tags.last().and_then(|x| x.as_ref()) {
-            buffer.push('/');
-            buffer.push_str(tag);
+    #[inline]
+    pub fn boundaries_mut(&mut self) -> &mut [CharacterBoundary] {
+        &mut self.boundaries
+    }
+
+    /// Returns a slice of boundary scores.
+    #[inline]
+    pub fn boundary_scores(&self) -> &[i32] {
+        if self.boundary_scores.is_empty() {
+            &[]
+        } else {
+            &self.boundary_scores[self.score_padding..self.score_padding + self.boundaries.len()]
         }
     }
 
-    /// Gets a reference to the boundary information.
+    /// Returns a reference to the internal representation of tags.
     ///
-    /// # Returns
-    ///
-    /// A reference to the boundary information.
+    /// In the representation, tags are stored in a flat array, and
+    /// the `j`-th tag of the `i`-th character is stored in the `i*k+j`-th element,
+    /// where `k` is the maximum number of tags (i.e., [`Sentence::n_tags()`]).
     ///
     /// # Examples
     ///
     /// ```
-    /// use vaporetto::{BoundaryType, Sentence};
+    /// use vaporetto::Sentence;
     ///
-    /// let s = Sentence::from_partial_annotation("a|b-c d").unwrap();
-    /// assert_eq!(&[
-    ///     BoundaryType::WordBoundary,
-    ///     BoundaryType::NotWordBoundary,
-    ///     BoundaryType::Unknown,
-    /// ], s.boundaries());
+    /// let mut s = Sentence::from_tokenized("火星/名詞/カセー に 行き/動詞 まし/助動詞/マシ た").unwrap();
+    /// assert_eq!(16, s.tags().len());
+    /// assert_eq!("名詞", s.tags()[2].as_ref().unwrap().as_ref());
+    /// assert_eq!("カセー", s.tags()[3].as_ref().unwrap().as_ref());
+    /// assert_eq!("動詞", s.tags()[8].as_ref().unwrap().as_ref());
+    /// assert_eq!("助動詞", s.tags()[12].as_ref().unwrap().as_ref());
+    /// assert_eq!("マシ", s.tags()[13].as_ref().unwrap().as_ref());
     /// ```
-    pub fn boundaries(&self) -> &[BoundaryType] {
-        &self.boundaries
-    }
-
-    /// Gets a mutable reference to the boundary information.
-    ///
-    /// # Returns
-    ///
-    /// A mutable reference to the boundary information.
-    pub fn boundaries_mut(&mut self) -> &mut [BoundaryType] {
-        &mut self.boundaries
+    #[inline]
+    pub fn tags(&self) -> &[Option<Cow<'a, str>>] {
+        &self.tags
     }
 
-    /// Gets a reference to the part-of-speech information.
-    ///
-    /// Each tag is placed at the last of the corresponding token. For example, when the first token
-    /// containing three characters has a tag, that tag will be placed at the third element of the
-    /// returned slice.
+    /// Returns a mutable reference to the internal representation of tags.
     ///
-    /// # Returns
+    /// In the representation, tags are stored in a flat array, and
+    /// the `j`-th tag of the `i`-th character is stored in the `i*k+j`-th element,
+    /// where `k` is the maximum number of tags (i.e., [`Sentence::n_tags()`]).
     ///
-    /// A reference to the POS information.
+    /// Tags can also be inserted at other positions, but such tags are ignored.
     ///
     /// # Examples
     ///
     /// ```
-    /// use std::sync::Arc;
+    /// use vaporetto::Sentence;
     ///
-    /// use vaporetto::{BoundaryType, Sentence};
+    /// let mut buf = String::new();
     ///
-    /// let s = Sentence::from_tokenized("I/PRP am a/DT cat/NN ./.").unwrap();
-    /// assert_eq!(&[
-    ///     Some(Arc::new("PRP".to_string())), // 'I'
-    ///     None,                              // 'a'
-    ///     None,                              // 'm'
-    ///     Some(Arc::new("DT".to_string())),  // 'a'
-    ///     None,                              // 'c'
-    ///     None,                              // 'a'
-    ///     Some(Arc::new("NN".to_string())),  // 't'
-    ///     Some(Arc::new(".".to_string())),   // '.'
-    /// ], s.tags());
+    /// let mut s = Sentence::from_tokenized("火星/名詞/カセー に 行き/動詞 まし/助動詞/マシ た").unwrap();
+    /// s.tags_mut()[4].replace("助詞".into());
+    /// s.write_tokenized_text(&mut buf);
+    /// assert_eq!("火星/名詞/カセー に/助詞 行き/動詞 まし/助動詞/マシ た", buf);
+    ///
+    /// // Sets the pronunciation of the first character (`火`), but this character is not the last
+    /// // of a word, so the tag is ignored.
+    /// s.tags_mut()[1].replace("ヒ".into());
+    /// s.write_tokenized_text(&mut buf);
+    /// assert_eq!("火星/名詞/カセー に/助詞 行き/動詞 まし/助動詞/マシ た", buf);
     /// ```
-    pub fn tags(&self) -> &[Option<Arc<String>>] {
-        &self.tags
+    #[inline]
+    pub fn tags_mut(&mut self) -> &mut [Option<Cow<'a, str>>] {
+        &mut self.tags
     }
 
-    /// Gets a mutable reference to the part-of-speech information.
+    /// Updates the tag information.
+    /// To predict tags, call this function after [`Predictor::predict()`] has been called and the
+    /// word boundaries are fixed.
     ///
-    /// # Returns
+    /// # Panics
     ///
-    /// A mutable reference to the part-of-speech information.
-    pub fn tags_mut(&mut self) -> &mut [Option<Arc<String>>] {
-        &mut self.tags
+    /// Panics if the attached predictor was not created with `predict_tags = true`.
+    ///
+    #[cfg_attr(
+        feature = "std",
+        doc = "
+# Examples

```
use std::fs::File;

use vaporetto::{Model, Predictor, Sentence};

let f = File::open(\"../resources/model.bin\").unwrap();
let model = Model::read(f).unwrap();
let predictor = Predictor::new(model, true).unwrap();

let mut s = Sentence::from_raw(\"まぁ良いだろう\").unwrap();
predictor.predict(&mut s);
let mut buf = String::new();
s.write_tokenized_text(&mut buf);
assert_eq!(\"まぁ 良い だろう\", buf);

s.fill_tags();

s.write_tokenized_text(&mut buf);
assert_eq!(
    \"まぁ/副詞/マー 良い/形容詞/ヨイ だろう/助動詞/ダロー\",
    buf,
);
```
"
    )]
+    #[cfg(feature = "tag-prediction")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "tag-prediction")))]
+    #[inline]
+    pub fn fill_tags(&mut self) {
+        if let Some(p) = self.predictor.as_ref() {
+            p.predict_tags(self);
+        }
     }
 
-    /// Gets a reference to the characters.
-    ///
-    /// # Returns
-    ///
-    /// A reference to the characters.
+    /// Returns the maximum number of tags.
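The flat `i*k+j` layout documented for [`Sentence::tags()`] above can be navigated with `n_tags()`. A worked editorial sketch reusing the sentence from the doc examples: `行き` ends at character index 4 (`火` = 0, `星` = 1, `に` = 2, `行` = 3, `き` = 4), its POS tag occupies slot `j = 0`, and `k = n_tags() = 2`, so the tag sits at index `4 * 2 + 0 = 8`:

```rust
use vaporetto::Sentence;

fn main() {
    let s = Sentence::from_tokenized("火星/名詞/カセー に 行き/動詞 まし/助動詞/マシ た").unwrap();

    // The j-th tag of the i-th character is stored at i * k + j.
    let (i, j, k) = (4, 0, s.n_tags());
    assert_eq!("動詞", s.tags()[i * k + j].as_ref().unwrap().as_ref());
}
```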
/// /// # Examples /// /// ``` /// use vaporetto::Sentence; /// - /// let s = Sentence::from_raw("A1あエ漢?").unwrap(); - /// assert_eq!(&['A', '1', 'あ', 'エ', '漢', '?'], s.chars()); + /// let s = Sentence::from_tokenized("火星/名詞/カセー に 行き/動詞 まし/助動詞/マシ た").unwrap(); + /// assert_eq!(2, s.n_tags()); /// ``` - pub fn chars(&self) -> &[char] { - &self.chars + #[inline] + pub const fn n_tags(&self) -> usize { + self.n_tags } - /// Gets immutable references to the characters and character types, and a mutable reference to - /// boundaries. - /// - /// # Returns - /// - /// A tuple of references. - /// - /// # Examples - /// - /// ``` - /// use vaporetto::{BoundaryType, CharacterType, Sentence}; - /// - /// let mut s = Sentence::from_partial_annotation("A-1|あ エ-漢|?").unwrap(); - /// let (chars, char_types, boundaries) = s.chars_and_boundaries_mut(); - /// assert_eq!(&['A', '1', 'あ', 'エ', '漢', '?'], chars); - /// assert_eq!(&[ - /// CharacterType::Roman as u8, - /// CharacterType::Digit as u8, - /// CharacterType::Hiragana as u8, - /// CharacterType::Katakana as u8, - /// CharacterType::Kanji as u8, - /// CharacterType::Other as u8, - /// ], char_types); - /// assert_eq!(&[ - /// BoundaryType::NotWordBoundary, - /// BoundaryType::WordBoundary, - /// BoundaryType::Unknown, - /// BoundaryType::NotWordBoundary, - /// BoundaryType::WordBoundary, - /// ], boundaries); - /// ``` - pub fn chars_and_boundaries_mut(&mut self) -> (&[char], &[u8], &mut [BoundaryType]) { - (&self.chars, &self.char_type, &mut self.boundaries) + #[inline] + pub(crate) fn len(&self) -> usize { + self.char_types.len() } - /// Gets a reference to the character type information. - /// - /// # Returns - /// - /// A reference to the character type information. - /// - /// # Examples - /// - /// ``` - /// use vaporetto::{CharacterType, Sentence}; - /// - /// let s = Sentence::from_raw("A1あエ漢?").unwrap(); - /// assert_eq!(&[ - /// CharacterType::Roman as u8, - /// CharacterType::Digit as u8, - /// CharacterType::Hiragana as u8, - /// CharacterType::Katakana as u8, - /// CharacterType::Kanji as u8, - /// CharacterType::Other as u8, - /// ], s.char_types()); - /// ``` - pub fn char_types(&self) -> &[u8] { - &self.char_type + #[inline] + pub(crate) fn set_predictor(&mut self, predictor: &'b Predictor) { + self.predictor.replace(predictor); } - /// Gets a reference to the boundary score information. - /// - /// # Returns + /// # Safety /// - /// If the predictor inserted, the boundary score information is returned. Otherwise, None. - pub fn boundary_scores(&self) -> &[i32] { - &self.boundary_scores + /// `pos` must be a position corresponding to a boundary in the UTF-8 format. + #[inline(always)] + pub(crate) unsafe fn str_to_char_pos(&self, pos: usize) -> usize { + *self.str_to_char_pos.get_unchecked(pos) } - /// Gets a character position in the code point unit. - /// - /// # Returns - /// - /// A position in the code point unit. - /// - /// # Errors - /// - /// `index` must be a valid position. 
-    pub fn get_char_pos(&self, index: usize) -> Result<usize> {
-        if index == 0 {
-            Ok(0)
-        } else {
-            match self.str_to_char_pos.get(index) {
-                Some(index) if *index != 0 => Ok(*index),
-                _ => Err(VaporettoError::invalid_argument("index", "invalid index")),
-            }
-        }
+    #[inline]
+    pub(crate) fn text_substring(&self, start: usize, end: usize) -> &str {
+        &self.text[self.char_to_str_pos[start]..self.char_to_str_pos[end]]
     }
 
-    #[cfg(feature = "train")]
-    pub(crate) fn char_substring(&self, start: usize, end: usize) -> &str {
-        let begin = self.char_to_str_pos[start];
-        let end = self.char_to_str_pos[end];
-        &self.text.as_str()[begin..end]
+    #[cfg(test)]
+    pub(crate) fn char_to_str_pos(&self) -> &[usize] {
+        &self.char_to_str_pos
     }
 }
 
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use BoundaryType::*;
-    use CharacterType::*;
+/// Token information.
+#[derive(Clone, Copy)]
+pub struct Token<'a, 'b> {
+    sentence: &'a Sentence<'a, 'b>,
+    start: usize,
+    end: usize,
+}
 
-    #[test]
-    fn test_sentence_from_raw_empty() {
-        let s = Sentence::from_raw("");
+impl<'a, 'b> Token<'a, 'b> {
+    /// Returns the surface of this token.
+    #[inline]
+    pub fn surface(&self) -> &'a str {
+        self.sentence.text_substring(self.start, self.end)
+    }
 
-        assert_eq!(
-            "InvalidArgumentError: raw_text: must contain at least one character",
+    /// Returns the tags of this token.
+    #[inline]
+    pub fn tags(&self) -> &'a [Option<Cow<'a, str>>] {
+        let start = (self.end - 1) * self.sentence.n_tags();
+        let end = self.end * self.sentence.n_tags();
+        &self.sentence.tags[start..end]
+    }
+
+    /// Returns the start position of this token in characters.
+    #[inline]
+    pub const fn start(&self) -> usize {
+        self.start
+    }
+
+    /// Returns the end position of this token in characters.
+    #[inline]
+    pub const fn end(&self) -> usize {
+        self.end
+    }
+}
+
+/// Iterator returned by [`Sentence::iter_tokens()`].
+pub struct TokenIterator<'a, 'b> {
+    token: Token<'a, 'b>,
+}
+
+impl<'a, 'b> Iterator for TokenIterator<'a, 'b> {
+    type Item = Token<'a, 'b>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.token.start = self.token.end;
+        if let Some(boundaries) = self.token.sentence.boundaries().get(self.token.start..)
{ + let mut skip_token = false; + for (i, &b) in boundaries.iter().enumerate() { + if b == CharacterBoundary::WordBoundary { + if skip_token { + self.token.start += i + 1; + skip_token = false; + } else { + self.token.end += i + 1; + return Some(self.token); + } + } else if b == CharacterBoundary::Unknown { + skip_token = true; + } + } + if skip_token { + self.token.end = self.token.sentence.boundaries().len() + 1; + return None; + } + } else { + return None; + } + self.token.end = self.token.sentence.boundaries().len() + 1; + Some(self.token) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use alloc::string::ToString; + + use CharacterBoundary::*; + use CharacterType::*; + + #[test] + fn test_sentence_from_raw_empty() { + let s = Sentence::from_raw(""); + + assert_eq!( + "InvalidArgumentError: text: must contain at least one character", &s.err().unwrap().to_string() ); } @@ -1246,22 +1278,16 @@ mod tests { let result = s.update_raw(""); assert_eq!( - "InvalidArgumentError: raw_text: must contain at least one character", + "InvalidArgumentError: text: must contain at least one character", &result.err().unwrap().to_string() ); - let expected = Sentence { - text: " ".to_string(), - chars: vec![' '], - str_to_char_pos: vec![0, 1], - char_to_str_pos: vec![0, 1], - char_type: vec![Other as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!(" ", s.as_raw_text()); + assert_eq!(&[0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 1], s.char_to_str_pos()); + assert_eq!([Other as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -1269,7 +1295,7 @@ mod tests { let s = Sentence::from_raw("A1あ\0ア亜"); assert_eq!( - "InvalidArgumentError: raw_text: must not contain NULL", + "InvalidArgumentError: text: must not contain NULL", &s.err().unwrap().to_string() ); } @@ -1280,40 +1306,28 @@ mod tests { let result = s.update_raw("A1あ\0ア亜"); assert_eq!( - "InvalidArgumentError: raw_text: must not contain NULL", + "InvalidArgumentError: text: must not contain NULL", &result.err().unwrap().to_string() ); - let expected = Sentence { - text: " ".to_string(), - chars: vec![' '], - str_to_char_pos: vec![0, 1], - char_to_str_pos: vec![0, 1], - char_type: vec![Other as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!(" ", s.as_raw_text()); + assert_eq!(&[0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 1], s.char_to_str_pos()); + assert_eq!([Other as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] fn test_sentence_from_raw_one() { - let s = Sentence::from_raw("あ"); - - let expected = Sentence { - text: "あ".to_string(), - chars: vec!['あ'], - str_to_char_pos: vec![0, 0, 0, 1], - char_to_str_pos: vec![0, 3], - char_type: vec![Hiragana as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s.unwrap()); + let s = Sentence::from_raw("あ").unwrap(); + + assert_eq!("あ", s.as_raw_text()); + assert_eq!(&[0, 0, 0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 3], s.char_to_str_pos()); + assert_eq!([Hiragana as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -1321,38 +1335,32 @@ mod tests { let mut s = 
Sentence::from_raw("12345").unwrap(); s.update_raw("あ").unwrap(); - let expected = Sentence { - text: "あ".to_string(), - chars: vec!['あ'], - str_to_char_pos: vec![0, 0, 0, 1], - char_to_str_pos: vec![0, 3], - char_type: vec![Hiragana as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!("あ", s.as_raw_text()); + assert_eq!(&[0, 0, 0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 3], s.char_to_str_pos()); + assert_eq!([Hiragana as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] fn test_sentence_from_raw() { - let s = Sentence::from_raw("Rustで良いプログラミング体験を!"); + let s = Sentence::from_raw("Rustで良いプログラミング体験を!").unwrap(); - let expected = Sentence { - text: "Rustで良いプログラミング体験を!".to_string(), - chars: vec![ - 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', - '体', '験', 'を', '!', - ], - str_to_char_pos: vec![ + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); + assert_eq!( + &[ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, ], - char_to_str_pos: vec![ - 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, - ], - char_type: vec![ + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ Roman as u8, Roman as u8, Roman as u8, @@ -1372,12 +1380,10 @@ mod tests { Hiragana as u8, Other as u8, ], - boundaries: vec![Unknown; 17], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 18], - }; - assert_eq!(expected, s.unwrap()); + s.char_types() + ); + assert_eq!([Unknown; 17], s.boundaries()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -1385,20 +1391,20 @@ mod tests { let mut s = Sentence::from_raw("12345").unwrap(); s.update_raw("Rustで良いプログラミング体験を!").unwrap(); - let expected = Sentence { - text: "Rustで良いプログラミング体験を!".to_string(), - chars: vec![ - 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', - '体', '験', 'を', '!', - ], - str_to_char_pos: vec![ + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); + assert_eq!( + &[ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, ], - char_to_str_pos: vec![ - 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, - ], - char_type: vec![ + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ Roman as u8, Roman as u8, Roman as u8, @@ -1418,22 +1424,10 @@ mod tests { Hiragana as u8, Other as u8, ], - boundaries: vec![Unknown; 17], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 18], - }; - assert_eq!(expected, s); - } - - #[test] - fn test_sentence_to_raw() { - let s = Sentence::from_raw("Rustで良いプログラミング体験を!"); - - assert_eq!( - "Rustで良いプログラミング体験を!", - s.unwrap().to_raw_string() + s.char_types() ); + assert_eq!([Unknown; 17], s.boundaries()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -1456,18 +1450,12 @@ mod tests { &result.err().unwrap().to_string() ); - let expected = Sentence { - text: " ".to_string(), - chars: vec![' '], - str_to_char_pos: vec![0, 1], - char_to_str_pos: vec![0, 1], - char_type: vec![Other as u8], 
- boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!(" ", s.as_raw_text()); + assert_eq!(&[0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 1], s.char_to_str_pos()); + assert_eq!([Other as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -1490,18 +1478,12 @@ mod tests { &result.err().unwrap().to_string() ); - let expected = Sentence { - text: " ".to_string(), - chars: vec![' '], - str_to_char_pos: vec![0, 1], - char_to_str_pos: vec![0, 1], - char_type: vec![Other as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!(" ", s.as_raw_text()); + assert_eq!(&[0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 1], s.char_to_str_pos()); + assert_eq!([Other as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -1524,18 +1506,12 @@ mod tests { &result.err().unwrap().to_string() ); - let expected = Sentence { - text: " ".to_string(), - chars: vec![' '], - str_to_char_pos: vec![0, 1], - char_to_str_pos: vec![0, 1], - char_type: vec![Other as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!(" ", s.as_raw_text()); + assert_eq!(&[0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 1], s.char_to_str_pos()); + assert_eq!([Other as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -1558,18 +1534,12 @@ mod tests { &result.err().unwrap().to_string() ); - let expected = Sentence { - text: " ".to_string(), - chars: vec![' '], - str_to_char_pos: vec![0, 1], - char_to_str_pos: vec![0, 1], - char_type: vec![Other as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!(" ", s.as_raw_text()); + assert_eq!(&[0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 1], s.char_to_str_pos()); + assert_eq!([Other as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -1592,36 +1562,24 @@ mod tests { &result.err().unwrap().to_string() ); - let expected = Sentence { - text: " ".to_string(), - chars: vec![' '], - str_to_char_pos: vec![0, 1], - char_to_str_pos: vec![0, 1], - char_type: vec![Other as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!(" ", s.as_raw_text()); + assert_eq!(&[0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 1], s.char_to_str_pos()); + assert_eq!([Other as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] fn test_sentence_from_tokenized_one() { - let s = Sentence::from_tokenized("あ"); - - let expected = Sentence { - text: "あ".to_string(), - chars: vec!['あ'], - str_to_char_pos: vec![0, 0, 0, 1], - char_to_str_pos: vec![0, 3], - char_type: vec![Hiragana as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s.unwrap()); + let s = Sentence::from_tokenized("あ").unwrap(); + + assert_eq!("あ", s.as_raw_text()); + assert_eq!(&[0, 0, 0, 1], s.str_to_char_pos.as_slice()); + 
assert_eq!([0, 3], s.char_to_str_pos()); + assert_eq!([Hiragana as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -1629,38 +1587,164 @@ mod tests { let mut s = Sentence::from_raw("12345").unwrap(); s.update_tokenized("あ").unwrap(); - let expected = Sentence { - text: "あ".to_string(), - chars: vec!['あ'], - str_to_char_pos: vec![0, 0, 0, 1], - char_to_str_pos: vec![0, 3], - char_type: vec![Hiragana as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!("あ", s.as_raw_text()); + assert_eq!(&[0, 0, 0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 3], s.char_to_str_pos()); + assert_eq!([Hiragana as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] fn test_sentence_from_tokenized() { - let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !"); + let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !").unwrap(); - let expected = Sentence { - text: "Rustで良いプログラミング体験を!".to_string(), - chars: vec![ - 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', - '体', '験', 'を', '!', + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); + assert_eq!( + &[ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, ], - str_to_char_pos: vec![ + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ + Roman as u8, + Roman as u8, + Roman as u8, + Roman as u8, + Hiragana as u8, + Kanji as u8, + Hiragana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Kanji as u8, + Kanji as u8, + Hiragana as u8, + Other as u8, + ], + s.char_types() + ); + assert_eq!( + [ + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + ], + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); + } + + #[test] + fn test_sentence_update_tokenized() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("Rust で 良い プログラミング 体験 を !") + .unwrap(); + + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); + assert_eq!( + &[ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, ], - char_to_str_pos: vec![ - 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ + Roman as u8, + Roman as u8, + Roman as u8, + Roman as u8, + Hiragana as u8, + Kanji as u8, + Hiragana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Kanji as u8, + Kanji as u8, + Hiragana as u8, + Other as u8, + ], + s.char_types() + ); + assert_eq!( + [ + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + 
NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, ], - char_type: vec![ + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); + } + + #[test] + fn test_sentence_from_tokenized_with_tags() { + let s = + Sentence::from_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号") + .unwrap(); + + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); + assert_eq!( + &[ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, + ], + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ Roman as u8, Roman as u8, Roman as u8, @@ -1680,7 +1764,10 @@ mod tests { Hiragana as u8, Other as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, NotWordBoundary, NotWordBoundary, @@ -1699,44 +1786,251 @@ mod tests { WordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 18], - }; - assert_eq!(expected, s.unwrap()); + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); + assert_eq!( + &[ + None, + None, + None, + Some(Cow::Borrowed("名詞")), + None, + None, + Some(Cow::Borrowed("形容詞")), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + Some(Cow::Borrowed("補助記号")), + ], + s.tags.as_slice() + ); + } + + #[test] + fn test_sentence_update_tokenized_with_tags() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号") + .unwrap(); + + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); + assert_eq!( + &[ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, + ], + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ + Roman as u8, + Roman as u8, + Roman as u8, + Roman as u8, + Hiragana as u8, + Kanji as u8, + Hiragana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Kanji as u8, + Kanji as u8, + Hiragana as u8, + Other as u8, + ], + s.char_types() + ); + assert_eq!( + [ + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + ], + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); + assert_eq!( + &[ + None, + None, + None, + Some(Cow::Borrowed("名詞")), + None, + None, + Some(Cow::Borrowed("形容詞")), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + Some(Cow::Borrowed("補助記号")), + ], + s.tags.as_slice() + ); } #[test] fn test_sentence_from_tokenized_with_tags_two_slashes() { let s = Sentence::from_tokenized( - "Rust/名詞 で 良い/形容詞/動詞 プログラミング 体験 を !/補助記号", - ); + "Rust/名詞 で 良い/形容詞/イイ プログラミング 体験 を !/補助記号", + ) + .unwrap(); + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); assert_eq!( - "InvalidArgumentError: tokenized_text: invalid slash found", - &s.err().unwrap().to_string() + &[ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 
9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, + ], + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ + Roman as u8, + Roman as u8, + Roman as u8, + Roman as u8, + Hiragana as u8, + Kanji as u8, + Hiragana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Katakana as u8, + Kanji as u8, + Kanji as u8, + Hiragana as u8, + Other as u8, + ], + s.char_types() + ); + assert_eq!( + [ + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + ], + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); + assert_eq!( + &[ + None, + None, + None, + None, + None, + None, + Some(Cow::Borrowed("名詞")), + None, + None, + None, + None, + None, + Some(Cow::Borrowed("形容詞")), + Some(Cow::Borrowed("イイ")), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + Some(Cow::Borrowed("補助記号")), + None, + ], + s.tags.as_slice() ); } #[test] - fn test_sentence_from_tokenized_with_tags() { - let s = - Sentence::from_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号"); + fn test_sentence_update_tokenized_two_slashes() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("Rust/名詞 で 良い/形容詞/イイ プログラミング 体験 を !/補助記号") + .unwrap(); - let expected = Sentence { - text: "Rustで良いプログラミング体験を!".to_string(), - chars: vec![ - 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', - '体', '験', 'を', '!', - ], - str_to_char_pos: vec![ + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); + assert_eq!( + &[ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, ], - char_to_str_pos: vec![ - 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, - ], - char_type: vec![ + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ Roman as u8, Roman as u8, Roman as u8, @@ -1756,7 +2050,10 @@ mod tests { Hiragana as u8, Other as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, NotWordBoundary, NotWordBoundary, @@ -1775,16 +2072,33 @@ mod tests { WordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![ + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); + assert_eq!( + &[ + None, + None, + None, + None, + None, + None, + Some(Cow::Borrowed("名詞")), + None, + None, + None, + None, + None, + Some(Cow::Borrowed("形容詞")), + Some(Cow::Borrowed("イイ")), + None, + None, + None, None, None, None, - Some(Arc::new("名詞".to_string())), None, None, - Some(Arc::new("形容詞".to_string())), None, None, None, @@ -1795,32 +2109,36 @@ mod tests { None, None, None, - Some(Arc::new("補助記号".to_string())), + None, + None, + Some(Cow::Borrowed("補助記号")), + None, ], - }; - assert_eq!(expected, s.unwrap()); + s.tags.as_slice() + ); } #[test] - fn test_sentence_update_tokenized() { - let mut s = Sentence::from_raw("12345").unwrap(); - 
s.update_tokenized("Rust で 良い プログラミング 体験 を !") - .unwrap(); + fn test_sentence_from_tokenized_with_tags_empty_slashes() { + let s = Sentence::from_tokenized( + "Rust//ラスト で 良い/形容詞/イイ プログラミング 体験 を !//ビックリ", + ) + .unwrap(); - let expected = Sentence { - text: "Rustで良いプログラミング体験を!".to_string(), - chars: vec![ - 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', - '体', '験', 'を', '!', - ], - str_to_char_pos: vec![ + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); + assert_eq!( + &[ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, ], - char_to_str_pos: vec![ - 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, - ], - char_type: vec![ + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ Roman as u8, Roman as u8, Roman as u8, @@ -1840,7 +2158,10 @@ mod tests { Hiragana as u8, Other as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, NotWordBoundary, NotWordBoundary, @@ -1859,58 +2180,72 @@ mod tests { WordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 18], - }; - assert_eq!(expected, s); - } - - #[test] - fn test_sentence_update_tokenized_two_slashes() { - let mut s = Sentence::from_raw("12345").unwrap(); - let result = - s.update_tokenized("Rust/名詞 で 良い/形容詞/動詞 プログラミング 体験 を !/補助記号"); - + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); assert_eq!( - "InvalidArgumentError: tokenized_text: invalid slash found", - &result.err().unwrap().to_string() + &[ + None, + None, + None, + None, + None, + None, + None, + Some(Cow::Borrowed("ラスト")), + None, + None, + None, + None, + Some(Cow::Borrowed("形容詞")), + Some(Cow::Borrowed("イイ")), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + Some(Cow::Borrowed("ビックリ")), + ], + s.tags.as_slice() ); - - let expected = Sentence { - text: " ".to_string(), - chars: vec![' '], - str_to_char_pos: vec![0, 1], - char_to_str_pos: vec![0, 1], - char_type: vec![Other as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); } #[test] - fn test_sentence_update_tokenized_with_tags() { + fn test_sentence_update_tokenized_empty_slashes() { let mut s = Sentence::from_raw("12345").unwrap(); - s.update_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号") + s.update_tokenized("Rust//ラスト で 良い/形容詞/イイ プログラミング 体験 を !//ビックリ") .unwrap(); - let expected = Sentence { - text: "Rustで良いプログラミング体験を!".to_string(), - chars: vec![ - 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', - '体', '験', 'を', '!', - ], - str_to_char_pos: vec![ + assert_eq!("Rustで良いプログラミング体験を!", s.as_raw_text()); + assert_eq!( + &[ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, ], - char_to_str_pos: vec![ - 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, - ], - char_type: vec![ + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,], + s.char_to_str_pos() + ); + assert_eq!( + [ Roman as u8, Roman as u8, Roman as u8, @@ -1930,7 +2265,10 @@ mod tests { 
Hiragana as u8, Other as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, NotWordBoundary, NotWordBoundary, @@ -1949,16 +2287,35 @@ mod tests { WordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![ + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); + assert_eq!( + &[ + None, + None, + None, + None, + None, + None, + None, + Some(Cow::Borrowed("ラスト")), + None, + None, + None, + None, + Some(Cow::Borrowed("形容詞")), + Some(Cow::Borrowed("イイ")), + None, + None, + None, + None, + None, None, None, None, - Some(Arc::new("名詞".to_string())), None, None, - Some(Arc::new("形容詞".to_string())), None, None, None, @@ -1969,30 +2326,31 @@ mod tests { None, None, None, - Some(Arc::new("補助記号".to_string())), + None, + Some(Cow::Borrowed("ビックリ")), ], - }; - assert_eq!(expected, s); + s.tags.as_slice() + ); } #[test] fn test_sentence_from_tokenized_with_escape_whitespace() { let s = Sentence::from_tokenized("火星 猫 の 生態 ( M \\ et\\ al. )").unwrap(); - let expected = Sentence { - text: "火星猫の生態(M et al.)".to_string(), - chars: vec![ - '火', '星', '猫', 'の', '生', '態', '(', 'M', ' ', 'e', 't', ' ', 'a', 'l', '.', - ')', - ], - str_to_char_pos: vec![ + assert_eq!("火星猫の生態(M et al.)", s.as_raw_text()); + assert_eq!( + &[ 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ], - char_to_str_pos: vec![ - 0, 3, 6, 9, 12, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - ], - char_type: vec![ + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 3, 6, 9, 12, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28], + s.char_to_str_pos() + ); + assert_eq!( + [ Kanji as u8, Kanji as u8, Kanji as u8, @@ -2010,7 +2368,10 @@ mod tests { Other as u8, Other as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, WordBoundary, WordBoundary, @@ -2027,11 +2388,9 @@ mod tests { NotWordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 16], - }; - assert_eq!(expected, s); + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -2040,20 +2399,20 @@ mod tests { s.update_tokenized("火星 猫 の 生態 ( M \\ et\\ al. 
)") .unwrap(); - let expected = Sentence { - text: "火星猫の生態(M et al.)".to_string(), - chars: vec![ - '火', '星', '猫', 'の', '生', '態', '(', 'M', ' ', 'e', 't', ' ', 'a', 'l', '.', - ')', - ], - str_to_char_pos: vec![ + assert_eq!("火星猫の生態(M et al.)", s.as_raw_text()); + assert_eq!( + &[ 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ], - char_to_str_pos: vec![ - 0, 3, 6, 9, 12, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - ], - char_type: vec![ + s.str_to_char_pos.as_slice(), + ); + assert_eq!( + [0, 3, 6, 9, 12, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28], + s.char_to_str_pos() + ); + assert_eq!( + [ Kanji as u8, Kanji as u8, Kanji as u8, @@ -2071,7 +2430,10 @@ mod tests { Other as u8, Other as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, WordBoundary, WordBoundary, @@ -2088,25 +2450,23 @@ mod tests { NotWordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 16], - }; - assert_eq!(expected, s); + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); } #[test] fn test_sentence_from_tokenized_with_escape_backslash() { - let s = Sentence::from_tokenized("改行 に \\\\n を 用い る"); + let s = Sentence::from_tokenized("改行 に \\\\n を 用い る").unwrap(); - let expected = Sentence { - text: "改行に\\nを用いる".to_string(), - chars: vec!['改', '行', 'に', '\\', 'n', 'を', '用', 'い', 'る'], - str_to_char_pos: vec![ - 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, - ], - char_to_str_pos: vec![0, 3, 6, 9, 10, 11, 14, 17, 20, 23], - char_type: vec![ + assert_eq!("改行に\\nを用いる", s.as_raw_text()); + assert_eq!( + &[0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9], + s.str_to_char_pos.as_slice(), + ); + assert_eq!([0, 3, 6, 9, 10, 11, 14, 17, 20, 23], s.char_to_str_pos()); + assert_eq!( + [ Kanji as u8, Kanji as u8, Hiragana as u8, @@ -2117,7 +2477,10 @@ mod tests { Hiragana as u8, Hiragana as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, WordBoundary, WordBoundary, @@ -2127,11 +2490,9 @@ mod tests { NotWordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 9], - }; - assert_eq!(expected, s.unwrap()); + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -2139,14 +2500,14 @@ mod tests { let mut s = Sentence::from_raw("12345").unwrap(); s.update_tokenized("改行 に \\\\n を 用い る").unwrap(); - let expected = Sentence { - text: "改行に\\nを用いる".to_string(), - chars: vec!['改', '行', 'に', '\\', 'n', 'を', '用', 'い', 'る'], - str_to_char_pos: vec![ - 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, - ], - char_to_str_pos: vec![0, 3, 6, 9, 10, 11, 14, 17, 20, 23], - char_type: vec![ + assert_eq!("改行に\\nを用いる", s.as_raw_text()); + assert_eq!( + &[0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9], + s.str_to_char_pos.as_slice(), + ); + assert_eq!([0, 3, 6, 9, 10, 11, 14, 17, 20, 23], s.char_to_str_pos()); + assert_eq!( + [ Kanji as u8, Kanji as u8, Hiragana as u8, @@ -2157,7 +2518,10 @@ mod tests { Hiragana as u8, Hiragana as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, WordBoundary, WordBoundary, @@ -2167,25 +2531,23 @@ mod tests { NotWordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 9], - }; - assert_eq!(expected, s); + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); } 
#[test] fn test_sentence_from_tokenized_escape_slash() { - let s = Sentence::from_tokenized("品詞 に \\/ を 用い る"); + let s = Sentence::from_tokenized("品詞 に \\/ を 用い る").unwrap(); - let expected = Sentence { - text: "品詞に/を用いる".to_string(), - chars: vec!['品', '詞', 'に', '/', 'を', '用', 'い', 'る'], - str_to_char_pos: vec![ - 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, - ], - char_to_str_pos: vec![0, 3, 6, 9, 10, 13, 16, 19, 22], - char_type: vec![ + assert_eq!("品詞に/を用いる", s.as_raw_text()); + assert_eq!( + &[0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8], + s.str_to_char_pos.as_slice(), + ); + assert_eq!([0, 3, 6, 9, 10, 13, 16, 19, 22], s.char_to_str_pos()); + assert_eq!( + [ Kanji as u8, Kanji as u8, Hiragana as u8, @@ -2195,7 +2557,10 @@ mod tests { Hiragana as u8, Hiragana as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, WordBoundary, WordBoundary, @@ -2204,11 +2569,9 @@ mod tests { NotWordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 8], - }; - assert_eq!(expected, s.unwrap()); + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -2216,14 +2579,14 @@ mod tests { let mut s = Sentence::from_raw("12345").unwrap(); s.update_tokenized("品詞 に \\/ を 用い る").unwrap(); - let expected = Sentence { - text: "品詞に/を用いる".to_string(), - chars: vec!['品', '詞', 'に', '/', 'を', '用', 'い', 'る'], - str_to_char_pos: vec![ - 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, - ], - char_to_str_pos: vec![0, 3, 6, 9, 10, 13, 16, 19, 22], - char_type: vec![ + assert_eq!("品詞に/を用いる", s.as_raw_text()); + assert_eq!( + &[0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8], + s.str_to_char_pos.as_slice(), + ); + assert_eq!([0, 3, 6, 9, 10, 13, 16, 19, 22], s.char_to_str_pos()); + assert_eq!( + [ Kanji as u8, Kanji as u8, Hiragana as u8, @@ -2233,7 +2596,10 @@ mod tests { Hiragana as u8, Hiragana as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, WordBoundary, WordBoundary, @@ -2242,103 +2608,93 @@ mod tests { NotWordBoundary, WordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 8], - }; - assert_eq!(expected, s); + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); } #[test] fn test_sentence_to_tokenized_string_unknown() { - let s = Sentence::from_partial_annotation("火-星 猫|の|生-態"); - let result = s.unwrap().to_tokenized_string(); + let s = Sentence::from_partial_annotation("火-星 猫|の|生-態").unwrap(); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); - assert_eq!( - "InvalidSentenceError: contains an unknown boundary", - result.err().unwrap().to_string() - ); + assert_eq!("の 生態", buf); } #[test] fn test_sentence_to_tokenized_string() { - let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !"); + let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !").unwrap(); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); - assert_eq!( - "Rust で 良い プログラミング 体験 を !", - s.unwrap().to_tokenized_string().unwrap() - ); + assert_eq!("Rust で 良い プログラミング 体験 を !", buf); } #[test] fn test_sentence_to_tokenized_string_with_tags() { let s = - Sentence::from_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号"); + Sentence::from_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号") + .unwrap(); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); assert_eq!( "Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号", - 
s.unwrap().to_tokenized_string().unwrap() + buf, ); } #[test] fn test_sentence_to_tokenized_string_escape() { - let s = Sentence::from_partial_annotation("火-星-猫|の| |生-態|\\-n"); + let s = Sentence::from_partial_annotation("火-星-猫|の| |生-態|\\-n").unwrap(); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); - assert_eq!( - "火星猫 の \\ 生態 \\\\n", - s.unwrap().to_tokenized_string().unwrap() - ); + assert_eq!("火星猫 の \\ 生態 \\\\n", buf); } #[test] fn test_sentence_to_tokenized_vec_unknown() { let s = Sentence::from_partial_annotation("火-星 猫|の|生-態").unwrap(); - let result = s.to_tokenized_vec(); + let mut it = s.iter_tokens(); - assert_eq!( - "InvalidSentenceError: contains an unknown boundary", - result.err().unwrap().to_string() - ); + let token = it.next().unwrap(); + assert_eq!("の", token.surface()); + + let token = it.next().unwrap(); + assert_eq!("生態", token.surface()); + + assert!(it.next().is_none()); } #[test] fn test_sentence_to_tokenized_vec() { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !").unwrap(); + let mut it = s.iter_tokens(); - assert_eq!( - vec![ - Token { - surface: "Rust", - tag: None - }, - Token { - surface: "で", - tag: None - }, - Token { - surface: "良い", - tag: None - }, - Token { - surface: "プログラミング", - tag: None - }, - Token { - surface: "体験", - tag: None - }, - Token { - surface: "を", - tag: None - }, - Token { - surface: "!", - tag: None - }, - ], - s.to_tokenized_vec().unwrap() - ); + let token = it.next().unwrap(); + assert_eq!("Rust", token.surface()); + + let token = it.next().unwrap(); + assert_eq!("で", token.surface()); + + let token = it.next().unwrap(); + assert_eq!("良い", token.surface()); + + let token = it.next().unwrap(); + assert_eq!("プログラミング", token.surface()); + + let token = it.next().unwrap(); + assert_eq!("体験", token.surface()); + + let token = it.next().unwrap(); + assert_eq!("を", token.surface()); + + let token = it.next().unwrap(); + assert_eq!("!", token.surface()); + + assert!(it.next().is_none()); } #[test] @@ -2346,40 +2702,37 @@ mod tests { let s = Sentence::from_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号") .unwrap(); + let mut it = s.iter_tokens(); - assert_eq!( - vec![ - Token { - surface: "Rust", - tag: Some("名詞"), - }, - Token { - surface: "で", - tag: None, - }, - Token { - surface: "良い", - tag: Some("形容詞"), - }, - Token { - surface: "プログラミング", - tag: None, - }, - Token { - surface: "体験", - tag: None, - }, - Token { - surface: "を", - tag: None, - }, - Token { - surface: "!", - tag: Some("補助記号"), - }, - ], - s.to_tokenized_vec().unwrap() - ); + let token = it.next().unwrap(); + assert_eq!("Rust", token.surface()); + assert_eq!("名詞", token.tags()[0].as_ref().unwrap()); + + let token = it.next().unwrap(); + assert_eq!("で", token.surface()); + assert!(token.tags()[0].is_none()); + + let token = it.next().unwrap(); + assert_eq!("良い", token.surface()); + assert_eq!("形容詞", token.tags()[0].as_ref().unwrap()); + + let token = it.next().unwrap(); + assert_eq!("プログラミング", token.surface()); + assert!(token.tags()[0].is_none()); + + let token = it.next().unwrap(); + assert_eq!("体験", token.surface()); + assert!(token.tags()[0].is_none()); + + let token = it.next().unwrap(); + assert_eq!("を", token.surface()); + assert!(token.tags()[0].is_none()); + + let token = it.next().unwrap(); + assert_eq!("!", token.surface()); + assert_eq!("補助記号", token.tags()[0].as_ref().unwrap()); + + assert!(it.next().is_none()); } #[test] @@ -2387,7 +2740,7 @@ mod tests { let s = Sentence::from_partial_annotation(""); assert_eq!( - 
"InvalidArgumentError: labeled_text: must contain at least one character", + "InvalidArgumentError: partial_annotation_text: must contain at least one character", &s.err().unwrap().to_string() ); } @@ -2398,22 +2751,16 @@ mod tests { let result = s.update_partial_annotation(""); assert_eq!( - "InvalidArgumentError: labeled_text: must contain at least one character", + "InvalidArgumentError: partial_annotation_text: must contain at least one character", &result.err().unwrap().to_string() ); - let expected = Sentence { - text: " ".to_string(), - chars: vec![' '], - str_to_char_pos: vec![0, 1], - char_to_str_pos: vec![0, 1], - char_type: vec![Other as u8], - boundaries: vec![], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None], - }; - assert_eq!(expected, s); + assert_eq!(" ", s.as_raw_text()); + assert_eq!(&[0, 1], s.str_to_char_pos.as_slice()); + assert_eq!([0, 1], s.char_to_str_pos()); + assert_eq!([Other as u8], s.char_types()); + assert!(s.boundaries().is_empty()); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -2421,7 +2768,7 @@ mod tests { let s = Sentence::from_partial_annotation("A-1-あ-\0-ア-亜"); assert_eq!( - "InvalidArgumentError: labeled_text: must not contain NULL", + "InvalidArgumentError: partial_annotation_text: must not contain NULL", &s.err().unwrap().to_string() ); } @@ -2432,7 +2779,7 @@ mod tests { let result = s.update_partial_annotation("A-1-あ-\0-ア-亜"); assert_eq!( - "InvalidArgumentError: labeled_text: must not contain NULL", + "InvalidArgumentError: partial_annotation_text: must not contain NULL", &result.err().unwrap().to_string() ); } @@ -2442,7 +2789,7 @@ mod tests { let result = Sentence::from_partial_annotation("火-星 猫|の|生-態 "); assert_eq!( - "InvalidArgumentError: labeled_text: invalid annotation", + "InvalidArgumentError: partial_annotation_text: invalid annotation", &result.err().unwrap().to_string() ); } @@ -2453,7 +2800,7 @@ mod tests { let result = s.update_partial_annotation("火-星 猫|の|生-態 "); assert_eq!( - "InvalidArgumentError: labeled_text: invalid annotation", + "InvalidArgumentError: partial_annotation_text: invalid annotation", &result.err().unwrap().to_string() ); } @@ -2463,7 +2810,7 @@ mod tests { let s = Sentence::from_partial_annotation("火-星?猫|の|生-態"); assert_eq!( - "InvalidArgumentError: labeled_text: contains an invalid boundary character: '?'", + "InvalidArgumentError: partial_annotation_text: contains an invalid boundary character: '?'", &s.err().unwrap().to_string() ); } @@ -2474,21 +2821,23 @@ mod tests { let result = s.update_partial_annotation("火-星?猫|の|生-態"); assert_eq!( - "InvalidArgumentError: labeled_text: contains an invalid boundary character: '?'", + "InvalidArgumentError: partial_annotation_text: contains an invalid boundary character: '?'", &result.err().unwrap().to_string() ); } #[test] fn test_sentence_from_partial_annotation_one() { - let s = Sentence::from_partial_annotation("火-星 猫|の|生-態"); - - let expected = Sentence { - text: "火星猫の生態".to_string(), - chars: vec!['火', '星', '猫', 'の', '生', '態'], - str_to_char_pos: vec![0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6], - char_to_str_pos: vec![0, 3, 6, 9, 12, 15, 18], - char_type: vec![ + let s = Sentence::from_partial_annotation("火-星 猫|の|生-態").unwrap(); + + assert_eq!("火星猫の生態", s.as_raw_text()); + assert_eq!( + &[0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6], + s.str_to_char_pos.as_slice(), + ); + assert_eq!([0, 3, 6, 9, 12, 15, 18], s.char_to_str_pos()); + assert_eq!( + [ Kanji as u8, Kanji as u8, Kanji as u8, @@ -2496,18 
+2845,19 @@ mod tests { Kanji as u8, Kanji as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, Unknown, WordBoundary, WordBoundary, NotWordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 6], - }; - assert_eq!(expected, s.unwrap()); + s.boundaries() + ); + assert!(s.boundary_scores().is_empty()); } #[test] @@ -2515,12 +2865,14 @@ mod tests { let mut s = Sentence::from_raw("12345").unwrap(); s.update_partial_annotation("火-星 猫|の|生-態").unwrap(); - let expected = Sentence { - text: "火星猫の生態".to_string(), - chars: vec!['火', '星', '猫', 'の', '生', '態'], - str_to_char_pos: vec![0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6], - char_to_str_pos: vec![0, 3, 6, 9, 12, 15, 18], - char_type: vec![ + assert_eq!("火星猫の生態", s.as_raw_text()); + assert_eq!( + &[0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6], + s.str_to_char_pos.as_slice(), + ); + assert_eq!([0, 3, 6, 9, 12, 15, 18], s.char_to_str_pos()); + assert_eq!( + [ Kanji as u8, Kanji as u8, Kanji as u8, @@ -2528,37 +2880,18 @@ mod tests { Kanji as u8, Kanji as u8, ], - boundaries: vec![ + s.char_types() + ); + assert_eq!( + [ NotWordBoundary, Unknown, WordBoundary, WordBoundary, NotWordBoundary, ], - boundary_scores: vec![], - tag_scores: TagScores::default(), - tags: vec![None; 6], - }; - assert_eq!(expected, s); - } - - #[test] - fn test_sentence_to_partial_annotation_string() { - let s = Sentence::from_partial_annotation("火-星 猫|の|生-態"); - - assert_eq!( - "火-星 猫|の|生-態", - s.unwrap().to_partial_annotation_string() - ); - } - - #[test] - fn test_sentence_to_partial_annotation_string_with_tags() { - let s = Sentence::from_partial_annotation("火-星 猫|の/助詞|生-態/名詞"); - - assert_eq!( - "火-星 猫|の/助詞|生-態/名詞", - s.unwrap().to_partial_annotation_string() + s.boundaries() ); + assert!(s.boundary_scores().is_empty()); } } diff --git a/vaporetto/src/tag_model.rs b/vaporetto/src/tag_model.rs deleted file mode 100644 index a90ea795..00000000 --- a/vaporetto/src/tag_model.rs +++ /dev/null @@ -1,28 +0,0 @@ -use alloc::string::String; -use alloc::vec::Vec; - -use bincode::{Decode, Encode}; - -use crate::ngram_model::NgramModel; - -#[derive(Decode, Encode)] -pub struct TagClassInfo { - pub(crate) name: String, - pub(crate) bias: i32, -} - -// Left and right weight arrays of the TagModel are ordered as follows: -// -// tok1 tok2 tok3 ... -// -// tag1 1 5 9 -// tag2 2 6 . -// tag3 3 7 . -// ... 4 8 . 
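The table in the deleted comment above describes a column-major layout: the weights for all tags at one token position are stored contiguously, one column per position. As a sketch of the indexing this implies (a hypothetical helper, not code from this patch), the weight for tag `tag_id` at token position `token_pos` lives at:

fn weight_index(token_pos: usize, tag_id: usize, n_tags: usize) -> usize {
    // Column-major: one column of n_tags weights per token position.
    token_pos * n_tags + tag_id
}

which matches how the deleted trainer code below fills these arrays with `i + pos as usize * self.tag_ids.len()`.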
-#[derive(Default, Decode, Encode)]
-pub struct TagModel {
-    pub(crate) class_info: Vec<TagClassInfo>,
-    pub(crate) left_char_model: NgramModel<String>,
-    pub(crate) right_char_model: NgramModel<String>,
-    pub(crate) self_char_model: NgramModel<String>,
-}
diff --git a/vaporetto/src/tag_trainer.rs b/vaporetto/src/tag_trainer.rs
index 8e58343f..f8ce5863 100644
--- a/vaporetto/src/tag_trainer.rs
+++ b/vaporetto/src/tag_trainer.rs
@@ -1,203 +1,316 @@
+use alloc::borrow::Cow;
 use alloc::collections::BTreeMap;
+use alloc::string::ToString;
+use hashbrown::HashMap;
 use liblinear::LibLinearModel;
 use crate::errors::{Result, VaporettoError};
-use crate::feature::{StringNgramFeature, TagExampleGenerator, TagFeature};
-use crate::ngram_model::{NgramData, NgramModel};
+use crate::model::TagModel;
+use crate::ngram_model::{TagNgramData, TagNgramModel, TagWeight};
 use crate::sentence::Sentence;
-use crate::tag_model::{TagClassInfo, TagModel};
-use crate::trainer::{Indexer, SolverType, QUANTIZE_BIT_DEPTH};
+use crate::trainer::{NgramFeature, SolverType};
+
+use crate::trainer::QUANTIZE_BIT_DEPTH;
+
+#[derive(Debug, Eq, Hash, PartialEq)]
+enum TagFeature<'a> {
+    CharacterNgram(NgramFeature<&'a str>),
+    CharacterTypeNgram(NgramFeature<&'a [u8]>),
+}
+
+impl<'a> TagFeature<'a> {
+    pub const fn char_ngram(ngram: &'a str, rel_position: isize) -> Self {
+        Self::CharacterNgram(NgramFeature {
+            ngram,
+            rel_position,
+        })
+    }
+
+    pub const fn type_ngram(ngram: &'a [u8], rel_position: isize) -> Self {
+        Self::CharacterTypeNgram(NgramFeature {
+            ngram,
+            rel_position,
+        })
+    }
+}
+
+#[derive(Debug)]
+struct TagExample<'a> {
+    tags: &'a [Option<Cow<'a, str>>],
+    features: Vec<TagFeature<'a>>,
+}
 pub struct TagTrainer<'a> {
-    example_generator: TagExampleGenerator,
-    char_window_size: u8,
-    feature_ids: Indexer<TagFeature<'a>>,
-    tag_ids: Indexer<String>,
-    xs: Vec<Vec<(u32, f64)>>,
-    ys: Vec<f64>,
+    _char_window_size: u8,
+    char_ngram_size: u8,
+    _type_window_size: u8,
+    type_ngram_size: u8,
+    // Uses BTreeMap to improve compression ratio.
+    examples: BTreeMap<&'a str, Vec<TagExample<'a>>>,
 }
 impl<'a> TagTrainer<'a> {
-    pub fn new(char_ngram_size: u8, char_window_size: u8) -> Self {
+    pub fn new(
+        char_window_size: u8,
+        char_ngram_size: u8,
+        type_window_size: u8,
+        type_ngram_size: u8,
+    ) -> Self {
         Self {
-            example_generator: TagExampleGenerator::new(char_ngram_size, char_window_size),
-            char_window_size,
-            feature_ids: Indexer::new(),
-            tag_ids: Indexer::new(),
-            xs: vec![],
-            ys: vec![],
+            _char_window_size: char_window_size,
+            char_ngram_size,
+            _type_window_size: type_window_size,
+            type_ngram_size,
+            examples: BTreeMap::new(),
         }
     }
-    pub fn push_sentence(&mut self, s: &'a Sentence) -> Result<()> {
-        let examples = self.example_generator.generate(s)?;
-        for example in examples {
-            let mut feature_ids = BTreeMap::new();
-            for f in &example.features {
-                let fid = self.feature_ids.get_id(f);
-                *feature_ids
-                    .entry((fid + 1).try_into().unwrap())
-                    .or_insert(0.0) += 1.0;
+    pub fn add_example<'b>(&mut self, sentence: &'a Sentence<'a, 'b>) {
+        for token in sentence.iter_tokens() {
+            if token.tags().is_empty() {
+                continue;
+            }
+            let mut features = vec![];
+            let token_len = token.end() - token.start();
+            for n in 0..usize::from(self.char_ngram_size) {
+                let ngram_len = token_len + n + 1;
+                for i in token.end().saturating_sub(ngram_len)
+                    ..(token.start() + 1).min(sentence.len().saturating_sub(ngram_len - 1))
+                {
+                    features.push(TagFeature::char_ngram(
+                        sentence.text_substring(i, i + ngram_len),
+                        isize::try_from(i + ngram_len - token.end()).unwrap(),
+                    ));
+                }
+            }
+            for n in 0..usize::from(self.type_ngram_size) {
+                let ngram_len = token_len + n + 1;
+                for i in token.end().saturating_sub(ngram_len)
+                    ..(token.start() + 1).min(sentence.len().saturating_sub(ngram_len - 1))
+                {
+                    features.push(TagFeature::type_ngram(
+                        &sentence.char_types()[i..i + ngram_len],
+                        isize::try_from(i + ngram_len - token.end()).unwrap(),
+                    ));
+                }
             }
-            self.xs.push(feature_ids.into_iter().collect());
-            self.ys
-                .push(self.tag_ids.get_id(example.tag.as_str()) as f64);
+            self.examples
+                .entry(token.surface())
+                .or_insert_with(Vec::new)
+                .push(TagExample {
+                    tags: token.tags(),
+                    features,
+                });
         }
-        Ok(())
     }
-    pub fn n_features(&self) -> usize {
-        self.feature_ids.len()
+    #[allow(clippy::type_complexity)]
+    fn gen_feature_vecs<'b>(
+        examples: &'b [TagExample<'a>],
+        idx: usize,
+        tag_ids: &HashMap<&'a str, usize>,
+    ) -> (
+        HashMap<&'b TagFeature<'a>, u32>,
+        Vec<Vec<(u32, f64)>>,
+        Vec<f64>,
+    ) {
+        let mut feature_ids = HashMap::new();
+        let mut xs = vec![];
+        let mut ys = vec![];
+        for example in examples {
+            if let Some(tag) = example.tags[idx].as_ref() {
+                ys.push(tag_ids[tag.as_ref()] as f64)
+            } else {
+                continue;
+            }
+            let mut feature_vec = vec![];
+            for feature in &example.features {
+                let new_id = u32::try_from(feature_ids.len() + 1).unwrap();
+                let feature_id = *feature_ids.entry(feature).or_insert(new_id);
+                feature_vec.push((feature_id, 1f64));
+            }
+            xs.push(feature_vec);
+        }
+        (feature_ids, xs, ys)
     }
-    pub fn train(self, epsilon: f64, cost: f64, solver: SolverType) -> Result<TagModel> {
-        if self.xs.is_empty() {
-            // Returns an empty model if there is no training data.
-            return Ok(TagModel::default());
+    fn train_tag(
+        token: String,
+        examples: &[TagExample<'a>],
+        epsilon: f64,
+        cost: f64,
+        solver: SolverType,
+    ) -> Result<TagModel> {
+        let n_tags = examples.iter().fold(0, |acc, x| acc.max(x.tags.len()));
+        let mut tag_ids = vec![HashMap::new(); n_tags];
+        let mut tags = vec![vec![]; n_tags];
+        for example in examples {
+            for ((tag, tag_ids), tags) in example.tags.iter().zip(&mut tag_ids).zip(&mut tags) {
+                if let Some(tag) = tag {
+                    if !tag_ids.contains_key(tag.as_ref()) {
+                        let new_id = tag_ids.len();
+                        tag_ids.insert(tag.as_ref(), new_id);
+                        tags.push(tag.to_string());
+                    }
+                }
+            }
         }
+        let n_class = tags
+            .iter()
+            .fold(0, |acc, x| acc + if x.len() >= 2 { x.len() } else { 0 });
-        let mut builder = liblinear::Builder::new();
-        let training_input = liblinear::util::TrainingInput::from_sparse_features(self.ys, self.xs)
-            .map_err(|e| VaporettoError::invalid_model(format!("liblinear error: {:?}", e)))?;
-        builder.problem().input_data(training_input).bias(1.0);
-        builder
-            .parameters()
-            .solver_type(solver.into())
-            .stopping_criterion(epsilon)
-            .constraints_violation_cost(cost);
-        let model = builder
-            .build_model()
-            .map_err(|e| VaporettoError::invalid_model(e.to_string()))?;
+        let mut bias = vec![0; n_class];
         // Uses BTreeMap to increase compression ratio.
-        let mut left_char_weights: BTreeMap<_, Vec<_>> = BTreeMap::new();
-        let mut right_char_weights: BTreeMap<_, Vec<_>> = BTreeMap::new();
-        let mut self_char_weights: BTreeMap<_, Vec<_>> = BTreeMap::new();
-
-        let mut weight_max = 0.;
-        for i in 0..i32::try_from(self.tag_ids.len())? {
-            let weight = model.label_bias(i).abs();
-            if weight > weight_max {
-                weight_max = weight;
-            }
-            for fid in 0..i32::try_from(model.num_features())? {
-                let weight = model.feature_coefficient(fid, i).abs();
-                if weight > weight_max {
-                    weight_max = weight;
-                }
-            }
-        }
-        let quantize_multiplier = weight_max / f64::from((1 << (QUANTIZE_BIT_DEPTH - 1)) - 1);
-        if quantize_multiplier == 0. {
-            return Err(VaporettoError::invalid_model("all weights are zero"));
-        }
+        let mut char_ngram_weights = BTreeMap::new();
+        let mut type_ngram_weights = BTreeMap::new();
-        let mut class_info = vec![];
+        let mut class_offset = 0;
+        for (i, tag_ids) in tag_ids.iter().enumerate() {
+            if tag_ids.len() <= 1 {
+                // fixed tag
+                continue;
+            }
-        for i in 0..self.tag_ids.len() {
-            class_info.push(TagClassInfo {
-                name: self.tag_ids.keys()[model.labels()[i] as usize].clone(),
-                bias: unsafe {
-                    (model.label_bias(i32::try_from(i)?) / quantize_multiplier).to_int_unchecked()
-                },
-            });
+            // train
+            let (feature_ids, xs, ys) = Self::gen_feature_vecs(examples, i, tag_ids);
-            for (fid, feature) in self.feature_ids.keys().iter().enumerate() {
-                let raw_weight =
-                    model.feature_coefficient(i32::try_from(fid + 1)?, i32::try_from(i)?);
-                let weight =
-                    unsafe { (raw_weight / quantize_multiplier).to_int_unchecked::<i32>() };
+            let mut builder = liblinear::Builder::new();
+            let training_input = liblinear::util::TrainingInput::from_sparse_features(ys, xs)
+                .map_err(|e| VaporettoError::invalid_model(format!("liblinear error: {:?}", e)))?;
+            builder.problem().input_data(training_input).bias(1.0);
+            builder
+                .parameters()
+                .solver_type(solver.into())
+                .stopping_criterion(epsilon)
+                .constraints_violation_cost(cost);
+            let model = builder
+                .build_model()
+                .map_err(|e| VaporettoError::invalid_model(e.to_string()))?;
-                if weight == 0 {
-                    continue;
+            // Calculates the quantize multiplier
+            let mut weight_max = 1e-6f64;
+            for i in 0..i32::try_from(tag_ids.len()).unwrap() {
+                let bias = model.label_bias(i).abs();
+                weight_max = weight_max.max(bias);
+                for fid in 0..model.num_features() {
+                    let weight = model.feature_coefficient(i32::try_from(fid + 1)?, i).abs();
+                    weight_max = weight_max.max(weight);
                 }
+            }
+            let quantize_multiplier = weight_max / f64::from((1 << (QUANTIZE_BIT_DEPTH - 1)) - 1);
+            for (i, &cls) in model.labels().iter().enumerate() {
+                bias[class_offset + usize::try_from(cls).unwrap()] = unsafe {
+                    (model.label_bias(i32::try_from(i).unwrap()) / quantize_multiplier)
+                        .to_int_unchecked::<i32>()
+                };
+            }
+            for (feature, fid) in feature_ids {
                 match feature {
-                    TagFeature::LeftCharacterNgram(StringNgramFeature {
-                        rel_position,
+                    TagFeature::CharacterNgram(NgramFeature {
                         ngram,
-                    }) => {
-                        let pos = -rel_position - 1;
-                        let idx = i + pos as usize * self.tag_ids.len();
-                        if let Some(weights) = left_char_weights.get_mut(*ngram) {
-                            weights[idx] = weight;
-                        } else {
-                            let mut weights =
-                                vec![0; usize::from(self.char_window_size) * self.tag_ids.len()];
-                            weights[idx] = weight;
-                            left_char_weights.insert(ngram.to_string(), weights);
-                        }
-                    }
-                    TagFeature::LeftCharacterNgramBos(StringNgramFeature {
-                        rel_position,
-                        ngram,
-                    }) => {
-                        let pos = -rel_position - 1;
-                        let idx = i + pos as usize * self.tag_ids.len();
-                        let ngram = "\0".to_string() + *ngram;
-                        left_char_weights.entry(ngram).or_insert_with(|| {
-                            vec![0; usize::from(self.char_window_size) * self.tag_ids.len()]
-                        })[idx] = weight;
-                    }
-                    TagFeature::RightCharacterNgram(StringNgramFeature {
                         rel_position,
-                        ngram,
                     }) => {
-                        let pos = isize::from(self.char_window_size) - rel_position;
-                        let idx = i + pos as usize * self.tag_ids.len();
-                        if let Some(weights) = right_char_weights.get_mut(*ngram) {
-                            weights[idx] = weight;
-                        } else {
-                            let mut weights =
-                                vec![0; usize::from(self.char_window_size) * self.tag_ids.len()];
-                            weights[idx] = weight;
-                            right_char_weights.insert(ngram.to_string(), weights);
+                        for (i, &cls) in model.labels().iter().enumerate() {
+                            let raw_weight = model.feature_coefficient(
+                                i32::try_from(fid)?,
+                                i32::try_from(i).unwrap(),
+                            );
+                            let weight = unsafe {
+                                (raw_weight / quantize_multiplier).to_int_unchecked::<i32>()
+                            };
+                            if weight == 0 {
+                                continue;
+                            }
+                            char_ngram_weights
+                                .entry((*ngram, u8::try_from(*rel_position).unwrap()))
+                                .or_insert_with(|| vec![0; n_class])
+                                [class_offset + usize::try_from(cls).unwrap()] = weight;
+                        }
                     }
-                    TagFeature::RightCharacterNgramEos(StringNgramFeature {
-                        rel_position,
+                    TagFeature::CharacterTypeNgram(NgramFeature {
                         ngram,
+                        rel_position,
                    }) => {
-                        let pos = isize::from(self.char_window_size) - rel_position;
-                        let idx = i + pos as usize * self.tag_ids.len();
-                        let ngram = ngram.to_string() + "\0";
-                        right_char_weights.entry(ngram).or_insert_with(|| {
-                            vec![0; usize::from(self.char_window_size) * self.tag_ids.len()]
-                        })[idx] = weight;
-                    }
-                    TagFeature::Character(ngram) => {
-                        if let Some(weights) = self_char_weights.get_mut(*ngram) {
-                            weights[i] = weight;
-                        } else {
-                            let mut weights = vec![0; self.tag_ids.len()];
-                            weights[i] = weight;
-                            self_char_weights.insert(ngram.to_string(), weights);
+                        for (i, &cls) in model.labels().iter().enumerate() {
+                            let raw_weight = model.feature_coefficient(
+                                i32::try_from(fid)?,
+                                i32::try_from(i).unwrap(),
+                            );
+                            let weight = unsafe {
+                                (raw_weight / quantize_multiplier).to_int_unchecked::<i32>()
+                            };
+                            if weight == 0 {
+                                continue;
+                            }
+                            type_ngram_weights
+                                .entry((*ngram, u8::try_from(*rel_position).unwrap()))
+                                .or_insert_with(|| vec![0; n_class])
+                                [class_offset + usize::try_from(cls).unwrap()] = weight;
+                        }
                     }
-                };
+                }
             }
+            class_offset += tag_ids.len();
+        }
+
+        let mut char_ngram_model = BTreeMap::new();
+        for ((ngram, rel_position), weights) in char_ngram_weights {
+            char_ngram_model
+                .entry(ngram.to_string())
+                .or_insert_with(Vec::new)
+                .push(TagWeight {
+                    rel_position,
+                    weights,
+                });
+        }
+        let mut type_ngram_model = BTreeMap::new();
+        for ((ngram, rel_position), weights) in type_ngram_weights {
+            type_ngram_model
+                .entry(ngram.to_vec())
+                .or_insert_with(Vec::new)
+                .push(TagWeight {
+                    rel_position,
+                    weights,
+                });
         }
         Ok(TagModel {
-            class_info,
-            left_char_model: NgramModel {
-                data: left_char_weights
-                    .into_iter()
-                    .map(|(ngram, weights)| NgramData { ngram, weights })
-                    .collect(),
-            },
-            right_char_model: NgramModel {
-                data: right_char_weights
+            token,
+            tags,
+            char_ngram_model: TagNgramModel(
+                char_ngram_model
                     .into_iter()
-                    .map(|(ngram, weights)| NgramData { ngram, weights })
+                    .map(|(ngram, weights)| TagNgramData { ngram, weights })
                     .collect(),
-            },
-            self_char_model: NgramModel {
-                data: self_char_weights
+            ),
+            type_ngram_model: TagNgramModel(
+                type_ngram_model
                     .into_iter()
-                    .map(|(ngram, weights)| NgramData { ngram, weights })
+                    .map(|(ngram, weights)| TagNgramData { ngram, weights })
                     .collect(),
-            },
+            ),
+            bias,
         })
     }
+
+    pub fn train(self, epsilon: f64, cost: f64, solver: SolverType) -> Result<Vec<TagModel>> {
+        let mut tag_models = vec![];
+        liblinear::toggle_liblinear_stdout_output(false);
+        let n_tokens = self.examples.len();
+        for (i, (token, examples)) in self.examples.into_iter().enumerate() {
+            tag_models.push(Self::train_tag(
+                token.into(),
+                &examples,
+                epsilon,
+                cost,
+                solver,
+            )?);
+            eprint!("Tags: {}/{}\r", i, n_tokens);
+        }
+        eprintln!("Tags: {}/{}", n_tokens, n_tokens);
+        liblinear::toggle_liblinear_stdout_output(true);
+        Ok(tag_models)
+    }
 }
diff --git a/vaporetto/src/trainer.rs b/vaporetto/src/trainer.rs
index 3cff7d6d..77b85ff0 100644
--- a/vaporetto/src/trainer.rs
+++ b/vaporetto/src/trainer.rs
@@ -1,65 +1,24 @@
-use std::borrow::Borrow;
-use std::collections::{BTreeMap, HashMap};
-use std::hash::Hash;
-use std::str::FromStr;
+use core::str::FromStr;
+use alloc::collections::BTreeMap;
+
+use hashbrown::HashMap;
+
+use daachorse::DoubleArrayAhoCorasick;
 use liblinear::LibLinearModel;
-use crate::dict_model::{DictModel, DictWeight, WordWeightRecord};
+use crate::dict_model::{DictModel, WordWeightRecord};
 use crate::errors::{Result, VaporettoError};
-use crate::feature::{
-    BoundaryExampleGenerator, BoundaryFeature, BytesNgramFeature, DictionaryWordFeature,
-    DictionaryWordPosition, StringNgramFeature,
-};
 use crate::model::Model;
 use crate::ngram_model::{NgramData, NgramModel};
-use crate::sentence::{BoundaryType, Sentence};
+use crate::sentence::{CharacterBoundary, Sentence};
 use crate::tag_trainer::TagTrainer;
 // Bit depth for weight quantization.
 pub const QUANTIZE_BIT_DEPTH: u8 = 16;
-pub struct Indexer<K> {
-    ids: HashMap<K, usize>,
-    keys: Vec<K>,
-}
-
-impl<K> Indexer<K>
-where
-    K: Eq + Hash,
-{
-    pub fn new() -> Self {
-        Self {
-            ids: HashMap::new(),
-            keys: vec![],
-        }
-    }
-
-    pub fn get_id<Q>(&mut self, key: &Q) -> usize
-    where
-        K: Borrow<Q>,
-        Q: ToOwned<Owned = K> + Eq + Hash,
-    {
-        if let Some(&id) = self.ids.get(key) {
-            id
-        } else {
-            let id = self.ids.len();
-            self.keys.push(key.to_owned());
-            self.ids.insert(key.to_owned(), id);
-            id
-        }
-    }
-
-    pub fn len(&self) -> usize {
-        self.keys.len()
-    }
-
-    pub fn keys(&self) -> &[K] {
-        &self.keys
-    }
-}
-
 /// Solver type.
+#[cfg_attr(docsrs, doc(cfg(feature = "train")))]
 #[derive(Clone, Copy, Debug)]
 pub enum SolverType {
     /// L2-regularized logistic regression (primal).
@@ -120,6 +79,69 @@ impl From<SolverType> for liblinear::SolverType {
     }
 }
+#[derive(Debug, Eq, Hash, PartialEq)]
+pub struct NgramFeature<T> {
+    pub ngram: T,
+    pub rel_position: isize,
+}
+
+#[derive(Debug, Eq, Hash, PartialEq)]
+pub enum DictionaryWordPosition {
+    Left,
+    Inside,
+    Right,
+}
+
+#[derive(Debug, Eq, Hash, PartialEq)]
+pub struct DictionaryWordFeature {
+    pub(crate) length: usize,
+    pub(crate) position: DictionaryWordPosition,
+}
+
+#[derive(Debug, Eq, Hash, PartialEq)]
+enum BoundaryFeature<'a> {
+    CharacterNgram(NgramFeature<&'a str>),
+    CharacterTypeNgram(NgramFeature<&'a [u8]>),
+    DictionaryWord(DictionaryWordFeature),
+}
+
+impl<'a> BoundaryFeature<'a> {
+    pub const fn char_ngram(ngram: &'a str, rel_position: isize) -> Self {
+        Self::CharacterNgram(NgramFeature {
+            ngram,
+            rel_position,
+        })
+    }
+
+    pub const fn type_ngram(ngram: &'a [u8], rel_position: isize) -> Self {
+        Self::CharacterTypeNgram(NgramFeature {
+            ngram,
+            rel_position,
+        })
+    }
+
+    pub const fn dict_word_left(length: usize) -> Self {
+        Self::DictionaryWord(DictionaryWordFeature {
+            length,
+            position: DictionaryWordPosition::Left,
+        })
+    }
+
+    pub const fn dict_word_inside(length: usize) -> Self {
+        Self::DictionaryWord(DictionaryWordFeature {
+            length,
+            position: DictionaryWordPosition::Inside,
+        })
+    }
+
+    pub const fn dict_word_right(length: usize) -> Self {
+        Self::DictionaryWord(DictionaryWordFeature {
+            length,
+            position: DictionaryWordPosition::Right,
+        })
+    }
+}
+
 /// Trainer.
 ///
 /// # Examples
@@ -133,13 +155,13 @@ impl From<SolverType> for liblinear::SolverType {
 /// let mut train_sents = vec![];
 /// let f = BufReader::new(File::open("dataset-train.txt").unwrap());
 /// for (i, line) in f.lines().enumerate() {
-///     train_sents.push(Sentence::from_tokenized(line.unwrap()).unwrap());
+///     train_sents.push(Sentence::from_tokenized(&line.unwrap()).unwrap());
 /// }
 ///
 /// let dict: Vec<String> = vec![];
-/// let mut trainer = Trainer::new(3, 3, 3, 3, &dict, 0).unwrap();
+/// let mut trainer = Trainer::new(3, 3, 3, 3, dict, 0).unwrap();
 /// for (i, s) in train_sents.iter().enumerate() {
-///     trainer.push_sentence(s);
+///     trainer.add_example(&s);
 /// }
 ///
 /// let model = trainer.train(0.01, 1., SolverType::L1RegularizedL2LossSVC).unwrap();
@@ -148,118 +170,154 @@ impl From<SolverType> for liblinear::SolverType {
 /// ```
 #[cfg_attr(docsrs, doc(cfg(feature = "train")))]
 pub struct Trainer<'a> {
-    dictionary: Vec<String>,
-    example_generator: BoundaryExampleGenerator,
     char_window_size: u8,
+    char_ngram_size: u8,
     type_window_size: u8,
-    dict_max_word_size: u8,
-    feature_ids: Indexer<BoundaryFeature<'a>>,
+    type_ngram_size: u8,
+    feature_ids: HashMap<BoundaryFeature<'a>, u32>,
+    dict_words: Vec<String>,
+    dict_pma: Option<DoubleArrayAhoCorasick>,
+    dict_word_max_len: u8,
     xs: Vec<Vec<(u32, f64)>>,
     ys: Vec<f64>,
+    tag_trainer: TagTrainer<'a>,
 }
 impl<'a> Trainer<'a> {
-    /// Creates a new dataset manager.
+    /// Creates a new trainer.
     ///
     /// # Arguments
     ///
-    /// * `char_ngram_size` - The character n-gram length.
     /// * `char_window_size` - The character window size.
-    /// * `type_ngram_size` - The character type n-gram length.
+    /// * `char_ngram_size` - The character n-gram length.
     /// * `type_window_size` - The character type window size.
-    /// * `dictionary` - A word dictionary.
-    /// * `dict_max_word_size` - Dictionary words greater than this value will be grouped together.
-    ///
-    /// # Returns
-    ///
-    /// A dataset manager.
+    /// * `type_ngram_size` - The character type n-gram length.
+    /// * `dict_words` - A word dictionary.
+    /// * `dict_word_max_len` - Dictionary words longer than this value will be grouped together,
+    ///   where the length is in characters.
     ///
     /// # Errors
     ///
     /// If invalid parameters are given, an error variant will be returned.
-    pub fn new<D, P>(
-        char_ngram_size: u8,
+    pub fn new(
         char_window_size: u8,
-        type_ngram_size: u8,
+        char_ngram_size: u8,
         type_window_size: u8,
-        dictionary: D,
-        dict_max_word_size: u8,
-    ) -> Result<Self>
-    where
-        D: AsRef<[P]>,
-        P: AsRef<[u8]> + AsRef<str>,
-    {
+        type_ngram_size: u8,
+        dict_words: Vec<String>,
+        dict_word_max_len: u8,
+    ) -> Result<Self> {
+        let dict_pma = if dict_words.is_empty() {
+            None
+        } else {
+            Some(
+                DoubleArrayAhoCorasick::new(&dict_words)
+                    .map_err(|e| VaporettoError::invalid_argument("dict_words", e.to_string()))?,
+            )
+        };
         Ok(Self {
-            dictionary: dictionary
-                .as_ref()
-                .iter()
-                .map(|word| (word.as_ref() as &str).to_string())
-                .collect(),
-            example_generator: BoundaryExampleGenerator::new(
-                char_ngram_size,
-                type_ngram_size,
-                char_window_size,
-                type_window_size,
-                Some(dictionary.as_ref()).filter(|d| !d.is_empty()),
-                dict_max_word_size,
-            )?,
             char_window_size,
+            char_ngram_size,
             type_window_size,
-            dict_max_word_size,
-            feature_ids: Indexer::new(),
+            type_ngram_size,
+            feature_ids: HashMap::new(),
+            dict_words,
+            dict_pma,
+            dict_word_max_len,
             xs: vec![],
             ys: vec![],
-            tag_trainer: TagTrainer::new(char_ngram_size, char_window_size),
+            tag_trainer: TagTrainer::new(
+                char_window_size,
+                char_ngram_size,
+                type_window_size,
+                type_ngram_size,
+            ),
         })
     }
-    /// Adds a sentence to the dataset.
-    ///
-    /// # Arguments
-    ///
-    /// * `s` - A sentence.
-    ///
-    /// # Errors
-    ///
-    /// [`VaporettoError::InvalidArgument`] will be returned if the maximum number of feature has
-    /// been reached.
-    pub fn push_sentence(&mut self, s: &'a Sentence) -> Result<()> {
-        let examples = self.example_generator.generate(s)?;
-        for example in examples {
-            let mut feature_ids = BTreeMap::new();
-            for f in &example.features {
-                let fid = self.feature_ids.get_id(f);
-                *feature_ids
-                    .entry((fid + 1).try_into().unwrap())
-                    .or_insert(0.0) += 1.0;
+    fn gen_features<'b>(
+        &self,
+        sentence: &'a Sentence<'a, 'b>,
+        examples: &mut Vec<(Vec<BoundaryFeature<'a>>, CharacterBoundary)>,
+    ) {
+        for (i, &b) in sentence.boundaries().iter().enumerate() {
+            let mut features = vec![];
+            // adds character n-gram features
+            for n in 0..self.char_ngram_size {
+                for j in (i + 1).saturating_sub(self.char_window_size.into())
+                    ..(i + 1 + usize::from(self.char_window_size))
+                        .min(sentence.len())
+                        .saturating_sub(n.into())
+                {
+                    features.push(BoundaryFeature::char_ngram(
+                        sentence.text_substring(j, j + usize::from(n) + 1),
+                        isize::try_from(j).unwrap() - isize::try_from(i).unwrap() - 1,
+                    ));
+                }
+            }
+            // adds type n-gram features
+            for n in 0..self.type_ngram_size {
+                for j in (i + 1).saturating_sub(self.type_window_size.into())
+                    ..(i + 1 + usize::from(self.type_window_size))
+                        .min(sentence.len())
+                        .saturating_sub(n.into())
+                {
+                    features.push(BoundaryFeature::type_ngram(
+                        &sentence.char_types()[j..j + usize::from(n) + 1],
+                        isize::try_from(j).unwrap() - isize::try_from(i).unwrap() - 1,
+                    ));
+                }
+            }
+            examples.push((features, b));
+        }
+        // adds dictionary features
+        if let Some(pma) = self.dict_pma.as_ref() {
+            for m in pma.find_overlapping_iter(sentence.text.as_ref()) {
+                debug_assert!(sentence.text.is_char_boundary(m.start()));
+                let start = unsafe { sentence.str_to_char_pos(m.start()) };
+                debug_assert!(sentence.text.is_char_boundary(m.end()));
+                let end = unsafe { sentence.str_to_char_pos(m.end()) };
+                let length = (end - start).min(usize::from(self.dict_word_max_len));
+                if start != 0 {
+                    examples[start - 1]
+                        .0
+                        .push(BoundaryFeature::dict_word_left(length));
+                }
+                for example in &mut examples[start..end - 1] {
+                    example.0.push(BoundaryFeature::dict_word_inside(length));
+                }
+                if end != sentence.len() {
+                    examples[end - 1]
+                        .0
+                        .push(BoundaryFeature::dict_word_right(length));
+                }
             }
-            self.xs.push(feature_ids.into_iter().collect());
-            self.ys.push(f64::from(example.label as u8));
         }
-        self.tag_trainer.push_sentence(s)?;
-        Ok(())
     }
-    /// Gets the number of features.
-    ///
-    /// # Returns
-    ///
-    /// The number of features.
-    pub fn n_features(&self) -> usize {
-        self.feature_ids.len()
-    }
+    /// Adds a sentence to the trainer.
+    pub fn add_example<'b>(&mut self, sentence: &'a Sentence<'a, 'b>) {
+        let mut examples = vec![];
+        self.gen_features(sentence, &mut examples);
+        for (features, b) in examples {
+            let mut feature_vector = HashMap::new();
+            for feature in features {
+                let new_id = self.feature_ids.len() + 1;
+                let feature_id = *self
+                    .feature_ids
+                    .entry(feature)
+                    .or_insert(new_id.try_into().unwrap());
+                *feature_vector.entry(feature_id).or_insert(0f64) += 1f64;
+            }
+            self.xs.push(feature_vector.into_iter().collect());
+            self.ys.push(f64::from(b as u8));
+        }
-    /// Gets the number of tag features.
-    ///
-    /// # Returns
-    ///
-    /// The number of tag features.
-    pub fn n_tag_features(&self) -> usize {
-        self.tag_trainer.n_features()
+        self.tag_trainer.add_example(sentence);
     }
-    /// Trains word boundaries.
+    /// Trains word boundaries and tags.
     ///
     /// # Arguments
     ///
@@ -267,9 +325,9 @@ impl<'a> Trainer<'a> {
     /// * `cost` - The parameter C.
     /// * `solver` - Solver type.
     ///
-    /// # Returns
+    /// # Errors
     ///
-    /// A trained model.
+    /// If the solver returns an error, that will be propagated.
     pub fn train(self, epsilon: f64, cost: f64, solver: SolverType) -> Result<Model> {
         let mut builder = liblinear::Builder::new();
         let training_input = liblinear::util::TrainingInput::from_sparse_features(self.ys, self.xs)
@@ -288,33 +346,36 @@ impl<'a> Trainer<'a> {
         model
             .labels()
             .iter()
-            .position(|&cls| BoundaryType::WordBoundary as i32 == cls)
+            .position(|&cls| CharacterBoundary::WordBoundary as i32 == cls)
             .unwrap(),
         )?;
         let bias = model.label_bias(wb_idx);
-        // Uses BTreeMap to increase compression ratio.
-        let mut char_ngram_weights: BTreeMap<_, Vec<_>> = BTreeMap::new();
-        let mut type_ngram_weights: BTreeMap<_, Vec<_>> = BTreeMap::new();
-        let mut dict_weights = vec![DictWeight::default(); self.dict_max_word_size.into()];
-
         let mut weight_max = bias.abs();
         for fid in 0..model.num_features() {
-            let weight = model.feature_coefficient(i32::try_from(fid)?, wb_idx).abs();
-            if weight > weight_max {
-                weight_max = weight;
-            }
+            let weight = model
+                .feature_coefficient(i32::try_from(fid + 1)?, wb_idx)
+                .abs();
+            weight_max = weight_max.max(weight);
         }
         let quantize_multiplier = weight_max / f64::from((1 << (QUANTIZE_BIT_DEPTH - 1)) - 1);
         if quantize_multiplier == 0. {
             return Err(VaporettoError::invalid_model("all weights are zero"));
         }
+        // Uses BTreeMap to improve compression ratio.
+        let mut char_ngram_weights: BTreeMap<_, Vec<_>> = BTreeMap::new();
+        let mut type_ngram_weights: BTreeMap<_, Vec<_>> = BTreeMap::new();
+        let mut dict_weights = vec![];
+        for i in 0..usize::from(self.dict_word_max_len) {
+            dict_weights.push(vec![0; i + 2]);
+        }
+
         let bias = unsafe { (bias / quantize_multiplier).to_int_unchecked::<i32>() };
-        for (fid, feature) in self.feature_ids.keys().iter().enumerate() {
-            let raw_weight = model.feature_coefficient(i32::try_from(fid)? + 1, wb_idx);
+        for (feature, fid) in self.feature_ids {
+            let raw_weight = model.feature_coefficient(i32::try_from(fid)?, wb_idx);
             let weight = unsafe { (raw_weight / quantize_multiplier).to_int_unchecked::<i32>() };
             if weight == 0 {
@@ -322,71 +383,75 @@
             }
             match feature {
-                BoundaryFeature::CharacterNgram(StringNgramFeature {
-                    rel_position,
+                BoundaryFeature::CharacterNgram(NgramFeature {
                     ngram,
+                    rel_position,
                 }) => {
                     let len = ngram.chars().count();
-                    let pos =
-                        isize::from(self.char_window_size) - isize::try_from(len)? - rel_position;
-                    if let Some(weights) = char_ngram_weights.get_mut(*ngram) {
-                        weights[pos as usize] = weight;
+                    let pos = usize::try_from(
+                        isize::from(self.char_window_size) - isize::try_from(len)? - rel_position,
+                    )
+                    .unwrap();
+                    if let Some(weights) = char_ngram_weights.get_mut(ngram) {
+                        weights[pos] = weight;
                     } else {
                         let mut weights = vec![0; usize::from(self.char_window_size) * 2 - len + 1];
-                        weights[pos as usize] = weight;
+                        weights[pos] = weight;
                         char_ngram_weights.insert(ngram.to_string(), weights);
                     }
                 }
-                BoundaryFeature::CharacterTypeNgram(BytesNgramFeature {
-                    rel_position,
+                BoundaryFeature::CharacterTypeNgram(NgramFeature {
                     ngram,
+                    rel_position,
                 }) => {
                     let len = ngram.len();
-                    let pos =
-                        isize::from(self.char_window_size) - isize::try_from(len)?
- rel_position; - if let Some(weights) = type_ngram_weights.get_mut(*ngram) { - weights[pos as usize] = weight; + let pos = usize::try_from( + isize::from(self.char_window_size) - isize::try_from(len)? - rel_position, + ) + .unwrap(); + if let Some(weights) = type_ngram_weights.get_mut(ngram) { + weights[pos] = weight; } else { let mut weights = vec![0; usize::from(self.char_window_size) * 2 - len + 1]; - weights[pos as usize] = weight; + weights[pos] = weight; type_ngram_weights.insert(ngram.to_vec(), weights); } } - BoundaryFeature::DictionaryWord(DictionaryWordFeature { position, length }) => { + BoundaryFeature::DictionaryWord(DictionaryWordFeature { length, position }) => { + let weights = &mut dict_weights[length - 1]; match position { - DictionaryWordPosition::Right => dict_weights[length - 1].right = weight, - DictionaryWordPosition::Inside => dict_weights[length - 1].inside = weight, - DictionaryWordPosition::Left => dict_weights[length - 1].left = weight, + DictionaryWordPosition::Left => *weights.first_mut().unwrap() = weight, + DictionaryWordPosition::Inside => weights[1..length - 1].fill(weight), + DictionaryWordPosition::Right => *weights.last_mut().unwrap() = weight, } } - }; + } } - let tag_model = self.tag_trainer.train(epsilon, cost, solver)?; + + let tag_models = self.tag_trainer.train(epsilon, cost, solver)?; + Ok(Model::new( - NgramModel { - data: char_ngram_weights + NgramModel( + char_ngram_weights .into_iter() .map(|(ngram, weights)| NgramData { ngram, weights }) .collect(), - }, - NgramModel { - data: type_ngram_weights + ), + NgramModel( + type_ngram_weights .into_iter() .map(|(ngram, weights)| NgramData { ngram, weights }) .collect(), - }, + ), DictModel::new( - self.dictionary + self.dict_words .into_iter() .map(|word| { let word_len = word.chars().count(); let idx = word_len.min(dict_weights.len()) - 1; - let mut weights = vec![dict_weights[idx].inside; word_len + 1]; - *weights.first_mut().unwrap() = dict_weights[idx].right; - *weights.last_mut().unwrap() = dict_weights[idx].left; WordWeightRecord { word, - weights, + weights: dict_weights[idx].clone(), comment: "".to_string(), } }) @@ -395,7 +460,386 @@ impl<'a> Trainer<'a> { bias, self.char_window_size, self.type_window_size, - tag_model, + tag_models, )) } + + /// Returns the number of boundary features. 
+ pub fn n_features(&self) -> usize { + self.feature_ids.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sentence::CharacterBoundary::*; + use crate::sentence::CharacterType::*; + + #[test] + fn check_features_3322() { + let s = Sentence::from_tokenized("これ は テスト です").unwrap(); + let trainer = Trainer::new(3, 3, 2, 2, vec![], 4).unwrap(); + let mut examples = vec![]; + trainer.gen_features(&s, &mut examples); + + // こ-れ + assert_eq!( + vec![ + BoundaryFeature::char_ngram("こ", -1), + BoundaryFeature::char_ngram("れ", 0), + BoundaryFeature::char_ngram("は", 1), + BoundaryFeature::char_ngram("テ", 2), + BoundaryFeature::char_ngram("これ", -1), + BoundaryFeature::char_ngram("れは", 0), + BoundaryFeature::char_ngram("はテ", 1), + BoundaryFeature::char_ngram("これは", -1), + BoundaryFeature::char_ngram("れはテ", 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8], 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], 1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], 0), + ], + examples[0].0, + ); + assert_eq!(NotWordBoundary, examples[0].1); + + // れ|は + assert_eq!( + vec![ + BoundaryFeature::char_ngram("こ", -2), + BoundaryFeature::char_ngram("れ", -1), + BoundaryFeature::char_ngram("は", 0), + BoundaryFeature::char_ngram("テ", 1), + BoundaryFeature::char_ngram("ス", 2), + BoundaryFeature::char_ngram("これ", -2), + BoundaryFeature::char_ngram("れは", -1), + BoundaryFeature::char_ngram("はテ", 0), + BoundaryFeature::char_ngram("テス", 1), + BoundaryFeature::char_ngram("これは", -2), + BoundaryFeature::char_ngram("れはテ", -1), + BoundaryFeature::char_ngram("はテス", 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8], 0), + BoundaryFeature::type_ngram(&[Katakana as u8], 1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Katakana as u8], 0), + ], + examples[1].0, + ); + assert_eq!(WordBoundary, examples[1].1); + + // は|テ + assert_eq!( + vec![ + BoundaryFeature::char_ngram("こ", -3), + BoundaryFeature::char_ngram("れ", -2), + BoundaryFeature::char_ngram("は", -1), + BoundaryFeature::char_ngram("テ", 0), + BoundaryFeature::char_ngram("ス", 1), + BoundaryFeature::char_ngram("ト", 2), + BoundaryFeature::char_ngram("これ", -3), + BoundaryFeature::char_ngram("れは", -2), + BoundaryFeature::char_ngram("はテ", -1), + BoundaryFeature::char_ngram("テス", 0), + BoundaryFeature::char_ngram("スト", 1), + BoundaryFeature::char_ngram("これは", -3), + BoundaryFeature::char_ngram("れはテ", -2), + BoundaryFeature::char_ngram("はテス", -1), + BoundaryFeature::char_ngram("テスト", 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8], 0), + BoundaryFeature::type_ngram(&[Katakana as u8], 1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8, Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], 0), + ], + examples[2].0, + ); + assert_eq!(WordBoundary, examples[2].1); + + // テ-ス + assert_eq!( + vec![ + BoundaryFeature::char_ngram("れ", -3), + BoundaryFeature::char_ngram("は", -2), + BoundaryFeature::char_ngram("テ", -1), + BoundaryFeature::char_ngram("ス", 0), + 
BoundaryFeature::char_ngram("ト", 1), + BoundaryFeature::char_ngram("で", 2), + BoundaryFeature::char_ngram("れは", -3), + BoundaryFeature::char_ngram("はテ", -2), + BoundaryFeature::char_ngram("テス", -1), + BoundaryFeature::char_ngram("スト", 0), + BoundaryFeature::char_ngram("トで", 1), + BoundaryFeature::char_ngram("れはテ", -3), + BoundaryFeature::char_ngram("はテス", -2), + BoundaryFeature::char_ngram("テスト", -1), + BoundaryFeature::char_ngram("ストで", 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8], 0), + BoundaryFeature::type_ngram(&[Katakana as u8], 1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], 0), + ], + examples[3].0, + ); + assert_eq!(NotWordBoundary, examples[3].1); + + // ス-ト + assert_eq!( + vec![ + BoundaryFeature::char_ngram("は", -3), + BoundaryFeature::char_ngram("テ", -2), + BoundaryFeature::char_ngram("ス", -1), + BoundaryFeature::char_ngram("ト", 0), + BoundaryFeature::char_ngram("で", 1), + BoundaryFeature::char_ngram("す", 2), + BoundaryFeature::char_ngram("はテ", -3), + BoundaryFeature::char_ngram("テス", -2), + BoundaryFeature::char_ngram("スト", -1), + BoundaryFeature::char_ngram("トで", 0), + BoundaryFeature::char_ngram("です", 1), + BoundaryFeature::char_ngram("はテス", -3), + BoundaryFeature::char_ngram("テスト", -2), + BoundaryFeature::char_ngram("ストで", -1), + BoundaryFeature::char_ngram("トです", 0), + BoundaryFeature::type_ngram(&[Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8], 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], 1), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8, Hiragana as u8], 0), + ], + examples[4].0, + ); + assert_eq!(NotWordBoundary, examples[4].1); + + // ト|で + assert_eq!( + vec![ + BoundaryFeature::char_ngram("テ", -3), + BoundaryFeature::char_ngram("ス", -2), + BoundaryFeature::char_ngram("ト", -1), + BoundaryFeature::char_ngram("で", 0), + BoundaryFeature::char_ngram("す", 1), + BoundaryFeature::char_ngram("テス", -3), + BoundaryFeature::char_ngram("スト", -2), + BoundaryFeature::char_ngram("トで", -1), + BoundaryFeature::char_ngram("です", 0), + BoundaryFeature::char_ngram("テスト", -3), + BoundaryFeature::char_ngram("ストで", -2), + BoundaryFeature::char_ngram("トです", -1), + BoundaryFeature::type_ngram(&[Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8], 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], 1), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8, Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], 0), + ], + examples[5].0, + ); + assert_eq!(WordBoundary, examples[5].1); + + // で-す + assert_eq!( + vec![ + BoundaryFeature::char_ngram("ス", -3), + BoundaryFeature::char_ngram("ト", -2), + BoundaryFeature::char_ngram("で", -1), + BoundaryFeature::char_ngram("す", 0), + BoundaryFeature::char_ngram("スト", -3), + BoundaryFeature::char_ngram("トで", -2), + BoundaryFeature::char_ngram("です", -1), + BoundaryFeature::char_ngram("ストで", -3), + BoundaryFeature::char_ngram("トです", -2), + BoundaryFeature::type_ngram(&[Katakana as u8], -2), + 
BoundaryFeature::type_ngram(&[Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8], 0), + BoundaryFeature::type_ngram(&[Katakana as u8, Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -1), + ], + examples[6].0, + ); + assert_eq!(NotWordBoundary, examples[6].1); + } + + #[test] + fn check_features_2222_dict() { + let s = Sentence::from_tokenized("これ は テスト です").unwrap(); + let trainer = Trainer::new( + 2, + 2, + 2, + 2, + vec!["これ".into(), "これは".into(), "テスト".into()], + 4, + ) + .unwrap(); + let mut examples = vec![]; + trainer.gen_features(&s, &mut examples); + + // こ-れ + assert_eq!( + vec![ + BoundaryFeature::char_ngram("こ", -1), + BoundaryFeature::char_ngram("れ", 0), + BoundaryFeature::char_ngram("は", 1), + BoundaryFeature::char_ngram("これ", -1), + BoundaryFeature::char_ngram("れは", 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8], 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], 1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], 0), + BoundaryFeature::dict_word_inside(2), + BoundaryFeature::dict_word_inside(3), + ], + examples[0].0, + ); + assert_eq!(NotWordBoundary, examples[0].1); + + // れ|は + assert_eq!( + vec![ + BoundaryFeature::char_ngram("こ", -2), + BoundaryFeature::char_ngram("れ", -1), + BoundaryFeature::char_ngram("は", 0), + BoundaryFeature::char_ngram("テ", 1), + BoundaryFeature::char_ngram("これ", -2), + BoundaryFeature::char_ngram("れは", -1), + BoundaryFeature::char_ngram("はテ", 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8], 0), + BoundaryFeature::type_ngram(&[Katakana as u8], 1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Katakana as u8], 0), + BoundaryFeature::dict_word_right(2), + BoundaryFeature::dict_word_inside(3), + ], + examples[1].0, + ); + assert_eq!(WordBoundary, examples[1].1); + + // は|テ + assert_eq!( + vec![ + BoundaryFeature::char_ngram("れ", -2), + BoundaryFeature::char_ngram("は", -1), + BoundaryFeature::char_ngram("テ", 0), + BoundaryFeature::char_ngram("ス", 1), + BoundaryFeature::char_ngram("れは", -2), + BoundaryFeature::char_ngram("はテ", -1), + BoundaryFeature::char_ngram("テス", 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8], 0), + BoundaryFeature::type_ngram(&[Katakana as u8], 1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8, Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], 0), + BoundaryFeature::dict_word_right(3), + BoundaryFeature::dict_word_left(3), + ], + examples[2].0, + ); + assert_eq!(WordBoundary, examples[2].1); + + // テ-ス + assert_eq!( + vec![ + BoundaryFeature::char_ngram("は", -2), + BoundaryFeature::char_ngram("テ", -1), + BoundaryFeature::char_ngram("ス", 0), + BoundaryFeature::char_ngram("ト", 1), + BoundaryFeature::char_ngram("はテ", -2), + BoundaryFeature::char_ngram("テス", -1), + BoundaryFeature::char_ngram("スト", 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8], 0), + 
BoundaryFeature::type_ngram(&[Katakana as u8], 1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], 0), + BoundaryFeature::dict_word_inside(3), + ], + examples[3].0, + ); + assert_eq!(NotWordBoundary, examples[3].1); + + // ス-ト + assert_eq!( + vec![ + BoundaryFeature::char_ngram("テ", -2), + BoundaryFeature::char_ngram("ス", -1), + BoundaryFeature::char_ngram("ト", 0), + BoundaryFeature::char_ngram("で", 1), + BoundaryFeature::char_ngram("テス", -2), + BoundaryFeature::char_ngram("スト", -1), + BoundaryFeature::char_ngram("トで", 0), + BoundaryFeature::type_ngram(&[Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8], 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], 1), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Katakana as u8, Hiragana as u8], 0), + BoundaryFeature::dict_word_inside(3), + ], + examples[4].0, + ); + assert_eq!(NotWordBoundary, examples[4].1); + + // ト|で + assert_eq!( + vec![ + BoundaryFeature::char_ngram("ス", -2), + BoundaryFeature::char_ngram("ト", -1), + BoundaryFeature::char_ngram("で", 0), + BoundaryFeature::char_ngram("す", 1), + BoundaryFeature::char_ngram("スト", -2), + BoundaryFeature::char_ngram("トで", -1), + BoundaryFeature::char_ngram("です", 0), + BoundaryFeature::type_ngram(&[Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8], 0), + BoundaryFeature::type_ngram(&[Hiragana as u8], 1), + BoundaryFeature::type_ngram(&[Katakana as u8, Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Katakana as u8, Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], 0), + BoundaryFeature::dict_word_right(3), + ], + examples[5].0, + ); + assert_eq!(WordBoundary, examples[5].1); + + // で-す + assert_eq!( + vec![ + BoundaryFeature::char_ngram("ト", -2), + BoundaryFeature::char_ngram("で", -1), + BoundaryFeature::char_ngram("す", 0), + BoundaryFeature::char_ngram("トで", -2), + BoundaryFeature::char_ngram("です", -1), + BoundaryFeature::type_ngram(&[Katakana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8], -1), + BoundaryFeature::type_ngram(&[Hiragana as u8], 0), + BoundaryFeature::type_ngram(&[Katakana as u8, Hiragana as u8], -2), + BoundaryFeature::type_ngram(&[Hiragana as u8, Hiragana as u8], -1), + ], + examples[6].0, + ); + assert_eq!(NotWordBoundary, examples[6].1); + } } diff --git a/vaporetto/src/type_scorer.rs b/vaporetto/src/type_scorer.rs index cb024b14..1352b28a 100644 --- a/vaporetto/src/type_scorer.rs +++ b/vaporetto/src/type_scorer.rs @@ -1,74 +1,70 @@ +mod boundary_scorer; + +#[cfg(feature = "tag-prediction")] +mod boundary_tag_scorer; + +#[cfg(feature = "cache-type-score")] +mod boundary_scorer_cache; + use core::cell::RefCell; +use core::ops::AddAssign; use alloc::collections::BTreeMap; use alloc::vec::Vec; -use bincode::{de::BorrowDecoder, error::DecodeError, BorrowDecode, Decode, Encode}; -use daachorse::DoubleArrayAhoCorasick; +use bincode::{BorrowDecode, Encode}; -use crate::errors::{Result, VaporettoError}; +use crate::errors::Result; use crate::ngram_model::NgramModel; use crate::sentence::Sentence; -use crate::utils::AddWeight; - -/// WARNING: The decode feature is inherently unsafe. 
Do not publish this feature outside this
-/// crate.
-#[derive(BorrowDecode, Encode)]
-pub enum TypeScorer {
-    Pma(TypeScorerPma),
-
-    #[cfg(feature = "cache-type-score")]
-    Cache(TypeScorerCache),
-}
-impl TypeScorer {
-    pub fn new(model: NgramModel<Vec<u8>>, window_size: u8) -> Result<Self> {
-        #[cfg(feature = "cache-type-score")]
-        let scorer = if window_size <= 3 {
-            Self::Cache(TypeScorerCache::new(model, window_size)?)
-        } else {
-            Self::Pma(TypeScorerPma::new(model, window_size)?)
-        };
+#[cfg(feature = "tag-prediction")]
+use crate::ngram_model::TagNgramModel;
-        #[cfg(not(feature = "cache-type-score"))]
-        let scorer = Self::Pma(TypeScorerPma::new(model, window_size)?);
+use boundary_scorer::TypeScorerBoundary;
-        Ok(scorer)
-    }
+#[cfg(feature = "cache-type-score")]
+use boundary_scorer_cache::TypeScorerBoundaryCache;
-    pub fn add_scores(&self, sentence: &Sentence, padding: u8, ys: &mut [i32]) {
-        match self {
-            TypeScorer::Pma(pma) => pma.add_scores(sentence, padding, ys),
+#[cfg(feature = "tag-prediction")]
+use boundary_tag_scorer::TypeScorerBoundaryTag;
-            #[cfg(feature = "cache-type-score")]
-            TypeScorer::Cache(cache) => cache.add_scores(sentence, &mut ys[padding.into()..]),
-        }
-    }
-}
+// If the cache-type-score feature is enabled and the window size of character type features is
+// less than or equal to this value, character type scores are cached.
+#[cfg(feature = "cache-type-score")]
+const CACHE_MAX_WINDOW_SIZE: u8 = 3;
-pub struct TypeScorerPma {
-    pma: DoubleArrayAhoCorasick,
-    weights: Vec<Vec<i32>>,
-    window_size: u8,
+#[derive(Default)]
+struct TypeWeightMerger<W> {
+    map: BTreeMap<Vec<u8>, RefCell<(W, bool)>>,
 }
-impl TypeScorerPma {
-    pub fn new(model: NgramModel<Vec<u8>>, window_size: u8) -> Result<Self> {
-        // key: ngram, value: (weight, check)
-        let mut weights_map: BTreeMap<Vec<u8>, RefCell<(Vec<i32>, bool)>> = BTreeMap::new();
-
-        for d in model.data {
-            weights_map.insert(d.ngram, RefCell::new((d.weights, false)));
+impl<W> TypeWeightMerger<W>
+where
+    for<'a> W: AddAssign<&'a W>,
+{
+    pub fn add<V>(&mut self, ngram: V, weight: W)
+    where
+        V: Into<Vec<u8>> + AsRef<[u8]>,
+    {
+        if let Some(data) = self.map.get_mut(ngram.as_ref()) {
+            let (prev_weight, _) = &mut *data.borrow_mut();
+            *prev_weight += &weight;
+        } else {
+            self.map.insert(ngram.into(), RefCell::new((weight, false)));
         }
+    }
+    #[must_use]
+    pub fn merge(self) -> Vec<(Vec<u8>, W)> {
         let mut stack = vec![];
-        for (ngram, data) in &weights_map {
+        for (ngram, data) in &self.map {
             if data.borrow().1 {
                 continue;
             }
             stack.push(data);
             for j in 1..ngram.len() {
-                if let Some(data) = weights_map.get(&ngram[j..]) {
+                if let Some(data) = self.map.get(&ngram[j..]) {
                     stack.push(data);
                     if data.borrow().1 {
                         break;
@@ -78,172 +74,401 @@ impl TypeScorerPma {
             let mut data_from = stack.pop().unwrap();
             data_from.borrow_mut().1 = true;
             while let Some(data_to) = stack.pop() {
-                let mut new_weight = data_from.borrow().0.clone();
-                for (w1, w2) in new_weight.iter_mut().zip(&data_to.borrow().0) {
-                    *w1 += w2;
-                }
-                let new_data = (new_weight, true);
-                *data_to.borrow_mut() = new_data;
+                let data_to_ref = &mut data_to.borrow_mut();
+                data_to_ref.1 = true;
+                data_to_ref.0 += &data_from.borrow().0;
                 data_from = data_to;
             }
         }
-        let mut ngrams = vec![];
-        let mut weights = vec![];
-        for (ngram, data) in weights_map {
-            ngrams.push(ngram);
-            weights.push(data.into_inner().0);
-        }
-        let pma = DoubleArrayAhoCorasick::new(ngrams)
-            .map_err(|_| VaporettoError::invalid_model("invalid character type n-grams"))?;
-        Ok(Self {
-            pma,
-            weights,
-            window_size,
-        })
-    }
-
-    pub fn add_scores(&self, sentence: &Sentence, padding: u8, ys: &mut [i32]) {
padding: u8, ys: &mut [i32]) { - for m in self - .pma - .find_overlapping_no_suffix_iter(&sentence.char_type) - { - let offset = usize::from(padding) + m.end() - usize::from(self.window_size) - 1; - // Both the weights and the PMA always have the same number of items. - // Therefore, the following code is safe. - let weights = unsafe { self.weights.get_unchecked(m.value()) }; - weights.add_weight(ys, offset); - } + self.map + .into_iter() + .map(|(ngram, weight)| (ngram, weight.into_inner().0)) + .collect() } } -impl<'de> BorrowDecode<'de> for TypeScorerPma { - /// WARNING: This function is inherently unsafe. Do not publish this function outside this - /// crate. - fn borrow_decode>(decoder: &mut D) -> Result { - let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?; - let (pma, _) = - unsafe { DoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) }; - Ok(Self { - pma, - weights: Decode::decode(decoder)?, - window_size: Decode::decode(decoder)?, - }) - } -} +/// WARNING: Decoding is inherently unsafe. Do not publish this struct outside this +/// crate. +#[derive(BorrowDecode, Encode)] +pub enum TypeScorer { + Boundary(TypeScorerBoundary), -impl Encode for TypeScorerPma { - fn encode( - &self, - encoder: &mut E, - ) -> Result<(), bincode::error::EncodeError> { - let pma_data = self.pma.serialize_to_vec(); - Encode::encode(&pma_data, encoder)?; - Encode::encode(&self.weights, encoder)?; - Encode::encode(&self.window_size, encoder)?; - Ok(()) - } -} + #[cfg(feature = "cache-type-score")] + BoundaryCache(TypeScorerBoundaryCache), -#[cfg(feature = "cache-type-score")] -#[derive(Decode, Encode)] -pub struct TypeScorerCache { - scores: Vec, - window_size: u8, - sequence_mask: usize, + #[cfg(feature = "tag-prediction")] + BoundaryTag(TypeScorerBoundaryTag), } -#[cfg(feature = "cache-type-score")] -impl TypeScorerCache { - pub fn new(model: NgramModel>, window_size: u8) -> Result { - let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)) - .map_err(|_| VaporettoError::invalid_model("invalid character type n-grams"))?; - let mut weights = vec![]; - for d in model.data { - if d.weights.len() <= 2 * usize::from(window_size) - d.ngram.len() { - return Err(VaporettoError::invalid_model( - "invalid size of weight vector", - )); - } - weights.push(d.weights); +impl TypeScorer { + pub fn new( + ngram_model: NgramModel>, + window_size: u8, + #[cfg(feature = "tag-prediction")] tag_ngram_model: Vec>>, + ) -> Result> { + if ngram_model.0.is_empty() || window_size == 0 { + return Ok(None); } - let sequence_size = u16::from(window_size) * 2; - let all_sequences = ALPHABET_SIZE.pow(sequence_size.into()); - - let mut sequence = vec![0u8; sequence_size.into()]; - let mut scores = vec![0; all_sequences]; - - for (i, score) in scores.iter_mut().enumerate() { - if !Self::seqid_to_seq(i, &mut sequence) { - continue; - } - let mut y = 0; - for m in pma.find_overlapping_iter(&sequence) { - y += weights[m.value()][usize::from(sequence_size) - m.end()]; + #[cfg(feature = "tag-prediction")] + if tag_ngram_model.is_empty() { + match window_size { + #[cfg(feature = "cache-type-score")] + 0..=CACHE_MAX_WINDOW_SIZE => Ok(Some(Self::BoundaryCache( + TypeScorerBoundaryCache::new(ngram_model, window_size)?, + ))), + _ => Ok(Some(Self::Boundary(TypeScorerBoundary::new( + ngram_model, + window_size, + )?))), } - *score = y; + } else { + Ok(Some(Self::BoundaryTag(TypeScorerBoundaryTag::new( + ngram_model, + window_size, + tag_ngram_model, + )?))) } - Ok(Self { - scores, - window_size, - 
sequence_mask: (1 << (ALPHABET_SHIFT * usize::from(sequence_size))) - 1, - }) + #[cfg(not(feature = "tag-prediction"))] + match window_size { + #[cfg(feature = "cache-type-score")] + 0..=CACHE_MAX_WINDOW_SIZE => Ok(Some(Self::BoundaryCache( + TypeScorerBoundaryCache::new(ngram_model, window_size)?, + ))), + _ => Ok(Some(Self::Boundary(TypeScorerBoundary::new( + ngram_model, + window_size, + )?))), + } } - pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { - let mut seqid = 0; - for i in 0..self.window_size { - if let Some(ct) = sentence.char_type.get(usize::from(i)) { - seqid = self.increment_seqid(seqid, *ct); - } else { - seqid = self.increment_seqid_without_char(seqid); - }; - } - for (i, y) in ys.iter_mut().enumerate() { - if let Some(ct) = sentence.char_type.get(i + usize::from(self.window_size)) { - seqid = self.increment_seqid(seqid, *ct); - } else { - seqid = self.increment_seqid_without_char(seqid); - }; - *y += self.get_score(seqid); + #[inline] + pub fn add_scores<'a, 'b>(&self, sentence: &mut Sentence<'a, 'b>) { + match self { + Self::Boundary(scorer) => scorer.add_scores(sentence), + + #[cfg(feature = "cache-type-score")] + Self::BoundaryCache(scorer) => scorer.add_scores(sentence), + + #[cfg(feature = "tag-prediction")] + Self::BoundaryTag(scorer) => scorer.add_scores(sentence), } } - #[allow(clippy::cast_possible_truncation)] - fn seqid_to_seq(mut seqid: usize, sequence: &mut [u8]) -> bool { - for type_id in sequence.iter_mut().rev() { - *type_id = (seqid & ALPHABET_MASK) as u8; - if usize::from(*type_id) == ALPHABET_MASK { - return false; // invalid - } - seqid >>= ALPHABET_SHIFT; + /// # Safety + /// + /// `token_id` must be smaller than `scorer.tag_weight.len()`. + /// `pos` must be smaller than `sentence.type_pma_states.len()`.
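+    ///
+    /// A sketch of the intended call pattern (the variable names here are
+    /// illustrative, not part of this crate):
+    ///
+    /// ```ignore
+    /// // The caller checks both bounds documented above before entering unsafe code.
+    /// let mut scores = vec![0i32; n_tag_candidates];
+    /// unsafe {
+    ///     type_scorer.add_tag_scores(token_id, pos, &sentence, &mut scores);
+    /// }
+    /// ```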
+ #[cfg(feature = "tag-prediction")] + #[inline] + pub unsafe fn add_tag_scores( + &self, + token_id: u32, + pos: usize, + sentence: &Sentence, + scores: &mut [i32], + ) { + match self { + Self::BoundaryTag(scorer) => scorer.add_tag_scores(token_id, pos, sentence, scores), + _ => panic!("unsupported"), } - assert_eq!(seqid, 0); - true } +} + +#[cfg(test)] +mod tests { + use super::*; - #[inline(always)] - fn get_score(&self, seqid: usize) -> i32 { - self.scores[seqid] + use crate::ngram_model::NgramData; + use crate::predictor::PositionalWeight; + use crate::CharacterType::*; + + use crate::predictor::WEIGHT_FIXED_LEN; + + #[cfg(feature = "tag-prediction")] + use crate::ngram_model::{TagNgramData, TagWeight}; + + #[rustfmt::skip] + #[test] + fn test_weight_merger() { + let mut merger = TypeWeightMerger::default(); + merger.add(b"eab".to_vec(), PositionalWeight::new(-3, vec![1, 2, 3, 4])); + merger.add(b"ab".to_vec(), PositionalWeight::new(-3, vec![2, 4, 6, 8, 10])); + merger.add(b"ab".to_vec(), PositionalWeight::new(-3, vec![3, 6, 9])); + merger.add(b"cd".to_vec(), PositionalWeight::new(-2, vec![4, 8, 12])); + assert_eq!( + vec![ + (b"ab".to_vec(), PositionalWeight::new(-3, vec![5, 10, 15, 8, 10])), + (b"cd".to_vec(), PositionalWeight::new(-2, vec![4, 8, 12])), + (b"eab".to_vec(), PositionalWeight::new(-3, vec![6, 12, 18, 12, 10])), + ], + merger.merge(), + ); } - #[inline(always)] - fn increment_seqid(&self, seqid: usize, char_type: u8) -> usize { - let char_id = usize::from(char_type); - debug_assert!((1..=6).contains(&char_id)); - ((seqid << ALPHABET_SHIFT) | char_id) & self.sequence_mask + #[test] + fn test_add_scores() { + // input: 我 ら は 全 世 界 の 国 民 + // n-grams: + // KH: 4 5 6 7 + // 1 2 3 4 5 6 + // KKK: 8 9 10 11 12 13 + // KK: 14 15 16 17 18 19 20 + // 14 15 16 17 18 19 20 + // 14 15 16 17 + // K: 25 26 27 28 + // 22 23 24 25 26 27 28 + // 21 22 23 24 25 26 27 28 + // 21 22 23 24 25 26 27 + // 21 22 23 24 25 + // 21 22 23 24 + let scorer = TypeScorerBoundary::new( + NgramModel(vec![ + NgramData { + ngram: vec![Kanji as u8, Hiragana as u8], + weights: vec![1, 2, 3, 4, 5, 6, 7], + }, + NgramData { + ngram: vec![Kanji as u8, Kanji as u8, Kanji as u8], + weights: vec![8, 9, 10, 11, 12, 13], + }, + NgramData { + ngram: vec![Kanji as u8, Kanji as u8], + weights: vec![14, 15, 16, 17, 18, 19, 20], + }, + NgramData { + ngram: vec![Kanji as u8], + weights: vec![21, 22, 23, 24, 25, 26, 27, 28], + }, + ]), + 4, + ) + .unwrap(); + let mut sentence = Sentence::from_raw("我らは全世界の国民").unwrap(); + sentence.score_padding = WEIGHT_FIXED_LEN - 1; + sentence.boundary_scores.clear(); + sentence + .boundary_scores + .resize(sentence.score_padding * 2 + sentence.len() - 1, 1); + scorer.add_scores(&mut sentence); + assert_eq!( + &[87, 135, 144, 174, 182, 192, 202, 148], + sentence.boundary_scores(), + ); } - #[inline(always)] - const fn increment_seqid_without_char(&self, seqid: usize) -> usize { - (seqid << ALPHABET_SHIFT) & self.sequence_mask + #[cfg(feature = "cache-type-score")] + #[test] + fn test_add_scores_cache_1() { + // input: 我 ら は 全 世 界 の 国 民 + // n-grams: + // KH: 3 4 5 + // 1 2 3 4 5 + // KKK: 6 7 8 9 + // KK: 10 11 12 13 14 + // 10 11 12 13 14 + // 10 11 12 + // K: 18 19 20 + // 15 16 17 18 19 20 + // 15 16 17 18 19 20 + // 15 16 17 18 19 20 + // 15 16 17 18 + // 15 16 17 + let scorer = TypeScorerBoundaryCache::new( + NgramModel(vec![ + NgramData { + ngram: vec![Kanji as u8, Hiragana as u8], + weights: vec![1, 2, 3, 4, 5], + }, + NgramData { + ngram: vec![Kanji as u8, Kanji as u8, Kanji as 
u8], + weights: vec![6, 7, 8, 9], + }, + NgramData { + ngram: vec![Kanji as u8, Kanji as u8], + weights: vec![10, 11, 12, 13, 14], + }, + NgramData { + ngram: vec![Kanji as u8], + weights: vec![15, 16, 17, 18, 19, 20], + }, + ]), + 3, + ) + .unwrap(); + let mut sentence = Sentence::from_raw("我らは全世界の国民").unwrap(); + sentence.score_padding = WEIGHT_FIXED_LEN - 1; + sentence.boundary_scores.clear(); + sentence + .boundary_scores + .resize(sentence.score_padding * 2 + sentence.len() - 1, 2); + scorer.add_scores(&mut sentence); + assert_eq!( + &[38, 66, 102, 84, 106, 139, 103, 74], + sentence.boundary_scores(), + ); } -} -#[cfg(feature = "cache-type-score")] -const ALPHABET_SIZE: usize = 8; -#[cfg(feature = "cache-type-score")] -const ALPHABET_MASK: usize = ALPHABET_SIZE - 1; -#[cfg(feature = "cache-type-score")] -const ALPHABET_SHIFT: usize = 3; + #[cfg(feature = "cache-type-score")] + #[test] + fn test_add_scores_cache_2() { + // input: 我 ら は 全 世 界 の 国 民 + // n-grams: + // KH: 2 3 + // 1 2 3 + // KKK: 4 5 + // KK: 6 7 8 + // 6 7 8 + // 6 7 + // K: 11 12 + // 9 10 11 12 + // 9 10 11 12 + // 9 10 11 12 + // 9 10 11 + // 9 10 + let scorer = TypeScorerBoundaryCache::new( + NgramModel(vec![ + NgramData { + ngram: vec![Kanji as u8, Hiragana as u8], + weights: vec![1, 2, 3], + }, + NgramData { + ngram: vec![Kanji as u8, Kanji as u8, Kanji as u8], + weights: vec![4, 5], + }, + NgramData { + ngram: vec![Kanji as u8, Kanji as u8], + weights: vec![6, 7, 8], + }, + NgramData { + ngram: vec![Kanji as u8], + weights: vec![9, 10, 11, 12], + }, + ]), + 2, + ) + .unwrap(); + let mut sentence = Sentence::from_raw("我らは全世界の国民").unwrap(); + sentence.score_padding = WEIGHT_FIXED_LEN - 1; + sentence.boundary_scores.clear(); + sentence + .boundary_scores + .resize(sentence.score_padding * 2 + sentence.len() - 1, 3); + scorer.add_scores(&mut sentence); + assert_eq!( + &[16, 27, 28, 50, 57, 45, 43, 31], + sentence.boundary_scores(), + ); + } + + #[cfg(feature = "tag-prediction")] + #[test] + fn test_add_scores_with_tags() { + // input: こ の 人 は 火 星 人 だ + // n-grams: + // HHK: 2 3 4 + // KH: 5 6 7 8 9 + // 5 6 7 + let scorer = TypeScorerBoundaryTag::new( + NgramModel(vec![ + NgramData { + ngram: vec![Hiragana as u8, Hiragana as u8, Kanji as u8], + weights: vec![1, 2, 3, 4], + }, + NgramData { + ngram: vec![Kanji as u8, Hiragana as u8], + weights: vec![5, 6, 7, 8, 9], + }, + ]), + 3, + vec![ + TagNgramModel(vec![ + TagNgramData { + ngram: vec![Hiragana as u8, Kanji as u8], + weights: vec![ + TagWeight { + rel_position: 0, + weights: vec![10, 11, 12], + }, + TagWeight { + rel_position: 1, + weights: vec![13, 14, 15], + }, + ], + }, + TagNgramData { + ngram: vec![Kanji as u8, Hiragana as u8], + weights: vec![ + TagWeight { + rel_position: 1, + weights: vec![16, 17, 18], + }, + TagWeight { + rel_position: 3, + weights: vec![19, 20, 21], + }, + ], + }, + TagNgramData { + ngram: vec![Kanji as u8, Kanji as u8, Kanji as u8], + weights: vec![TagWeight { + rel_position: 0, + weights: vec![22, 23, 24], + }], + }, + ]), + TagNgramModel(vec![]), + TagNgramModel(vec![ + TagNgramData { + ngram: vec![Kanji as u8, Hiragana as u8], + weights: vec![ + TagWeight { + rel_position: 0, + weights: vec![25, 26], + }, + TagWeight { + rel_position: 3, + weights: vec![27, 28], + }, + ], + }, + TagNgramData { + ngram: vec![Hiragana as u8, Kanji as u8, Kanji as u8, Kanji as u8], + weights: vec![TagWeight { + rel_position: 3, + weights: vec![29, 30], + }], + }, + ]), + ], + ) + .unwrap(); + let mut sentence = Sentence::from_raw("この人は火星人だ").unwrap(); 
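+        // 8 characters -> 7 boundaries. The score buffer is filled with 1 below,
+        // so each expected boundary score is 1 plus the sum of the matching
+        // n-gram weights at that position; the tag score buffers start from 1 too.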
+ sentence.score_padding = WEIGHT_FIXED_LEN - 1; + sentence.boundary_scores.clear(); + sentence + .boundary_scores + .resize(sentence.score_padding * 2 + sentence.len() - 1, 1); + scorer.add_scores(&mut sentence); + assert_eq!(&[8, 10, 12, 9, 15, 7, 8], sentence.boundary_scores()); + + let mut tag_scores = [1; 8]; + unsafe { + scorer.add_tag_scores(0, 2, &sentence, &mut tag_scores); + } + assert_eq!(&[27, 29, 31, 1, 1, 1, 1, 1], &tag_scores); + + let mut tag_scores = [1; 8]; + unsafe { + scorer.add_tag_scores(0, 6, &sentence, &mut tag_scores); + } + assert_eq!(&[39, 41, 43, 1, 1, 1, 1, 1], &tag_scores); + + let mut tag_scores = [1; 8]; + unsafe { + scorer.add_tag_scores(2, 3, &sentence, &mut tag_scores); + } + assert_eq!(&[55, 57, 1, 1, 1, 1, 1, 1], &tag_scores); + } +} diff --git a/vaporetto/src/type_scorer/boundary_scorer.rs b/vaporetto/src/type_scorer/boundary_scorer.rs new file mode 100644 index 00000000..844ea704 --- /dev/null +++ b/vaporetto/src/type_scorer/boundary_scorer.rs @@ -0,0 +1,79 @@ +use alloc::vec::Vec; + +use bincode::{ + de::BorrowDecoder, + enc::Encoder, + error::{DecodeError, EncodeError}, + BorrowDecode, Decode, Encode, +}; +use daachorse::DoubleArrayAhoCorasick; + +use crate::errors::{Result, VaporettoError}; +use crate::ngram_model::NgramModel; +use crate::predictor::{PositionalWeight, WeightVector}; +use crate::sentence::Sentence; +use crate::type_scorer::TypeWeightMerger; + +pub struct TypeScorerBoundary { + pma: DoubleArrayAhoCorasick, + weights: Vec>, +} + +impl<'de> BorrowDecode<'de> for TypeScorerBoundary { + /// WARNING: This function is inherently unsafe. Do not publish this function outside this + /// crate. + fn borrow_decode>(decoder: &mut D) -> Result { + let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?; + let (pma, _) = + unsafe { DoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) }; + Ok(Self { + pma, + weights: Decode::decode(decoder)?, + }) + } +} + +impl Encode for TypeScorerBoundary { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + let pma_data = self.pma.serialize_to_vec(); + Encode::encode(&pma_data, encoder)?; + Encode::encode(&self.weights, encoder)?; + Ok(()) + } +} + +impl TypeScorerBoundary { + pub fn new(ngram_model: NgramModel>, window_size: u8) -> Result { + let mut merger = TypeWeightMerger::default(); + for d in ngram_model.0 { + let weight = PositionalWeight::new(-i16::from(window_size), d.weights); + merger.add(d.ngram, weight); + } + let mut ngrams = vec![]; + let mut weights = vec![]; + for (ngram, weight) in merger.merge() { + ngrams.push(ngram); + weights.push(weight.into()); + } + let pma = DoubleArrayAhoCorasick::new(ngrams) + .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?; + Ok(Self { pma, weights }) + } + + #[allow(clippy::cast_possible_wrap)] + #[inline(always)] + pub fn add_scores<'a, 'b>(&self, sentence: &mut Sentence<'a, 'b>) { + for m in self + .pma + .find_overlapping_no_suffix_iter(&sentence.char_types) + { + debug_assert!(m.end() != 0 && m.end() <= sentence.char_types.len()); + debug_assert!(m.value() < self.weights.len()); + let weight = unsafe { self.weights.get_unchecked(m.value()) }; + weight.add_score( + (m.end() + sentence.score_padding - 1) as isize, + &mut sentence.boundary_scores, + ); + } + } +} diff --git a/vaporetto/src/type_scorer/boundary_scorer_cache.rs b/vaporetto/src/type_scorer/boundary_scorer_cache.rs new file mode 100644 index 00000000..868d623f --- /dev/null +++ b/vaporetto/src/type_scorer/boundary_scorer_cache.rs @@ -0,0 
+1,109 @@ +use alloc::vec::Vec; + +use bincode::{Decode, Encode}; +use daachorse::DoubleArrayAhoCorasick; + +use crate::errors::{Result, VaporettoError}; +use crate::ngram_model::NgramModel; +use crate::sentence::Sentence; + +const ALPHABET_SIZE: usize = 8; +const ALPHABET_MASK: usize = ALPHABET_SIZE - 1; +const ALPHABET_SHIFT: usize = 3; + +#[derive(Decode, Encode)] +pub struct TypeScorerBoundaryCache { + scores: Vec, + window_size: u8, + sequence_mask: usize, +} + +impl TypeScorerBoundaryCache { + pub fn new(model: NgramModel>, window_size: u8) -> Result { + let pma = DoubleArrayAhoCorasick::new(model.0.iter().map(|d| &d.ngram)) + .map_err(|_| VaporettoError::invalid_model("invalid character type n-grams"))?; + let mut weights = vec![]; + for d in model.0 { + weights.push(d.weights); + } + + let sequence_size = u16::from(window_size) * 2; + let all_sequences = ALPHABET_SIZE.pow(sequence_size.into()); + + let mut sequence = vec![0u8; sequence_size.into()]; + let mut scores = vec![0; all_sequences]; + + for (i, score) in scores.iter_mut().enumerate() { + if !Self::seqid_to_seq(i, &mut sequence) { + continue; + } + let mut y = 0; + for m in pma.find_overlapping_iter(&sequence) { + if let Some(w) = weights[m.value()].get(usize::from(sequence_size) - m.end()) { + y += *w; + } + } + *score = y; + } + + Ok(Self { + scores, + window_size, + sequence_mask: (1 << (ALPHABET_SHIFT * usize::from(sequence_size))) - 1, + }) + } + + #[inline(always)] + pub fn add_scores<'a, 'b>(&self, sentence: &mut Sentence<'a, 'b>) { + sentence.type_pma_states.clear(); + let mut seqid = 0; + for i in 0..self.window_size { + if let Some(ct) = sentence.char_types.get(usize::from(i)) { + seqid = self.increment_seqid(seqid, *ct); + } else { + seqid = self.increment_seqid_without_char(seqid); + }; + } + for (i, y) in sentence.boundary_scores + [sentence.score_padding..sentence.score_padding + sentence.boundaries.len()] + .iter_mut() + .enumerate() + { + if let Some(ct) = sentence.char_types.get(i + usize::from(self.window_size)) { + seqid = self.increment_seqid(seqid, *ct); + } else { + seqid = self.increment_seqid_without_char(seqid); + }; + *y += self.get_score(seqid); + } + } + + fn seqid_to_seq(mut seqid: usize, sequence: &mut [u8]) -> bool { + for type_id in sequence.iter_mut().rev() { + *type_id = u8::try_from(seqid & ALPHABET_MASK).unwrap(); + if usize::from(*type_id) == ALPHABET_MASK { + return false; // invalid + } + seqid >>= ALPHABET_SHIFT; + } + assert_eq!(seqid, 0); + true + } + + #[inline(always)] + fn get_score(&self, seqid: usize) -> i32 { + self.scores[seqid] + } + + #[inline(always)] + fn increment_seqid(&self, seqid: usize, char_type: u8) -> usize { + let char_id = usize::from(char_type); + debug_assert!((1..=6).contains(&char_id)); + ((seqid << ALPHABET_SHIFT) | char_id) & self.sequence_mask + } + + #[inline(always)] + const fn increment_seqid_without_char(&self, seqid: usize) -> usize { + (seqid << ALPHABET_SHIFT) & self.sequence_mask + } +} diff --git a/vaporetto/src/type_scorer/boundary_tag_scorer.rs b/vaporetto/src/type_scorer/boundary_tag_scorer.rs new file mode 100644 index 00000000..22ff5301 --- /dev/null +++ b/vaporetto/src/type_scorer/boundary_tag_scorer.rs @@ -0,0 +1,152 @@ +use alloc::vec::Vec; + +use bincode::{ + de::BorrowDecoder, + enc::Encoder, + error::{DecodeError, EncodeError}, + BorrowDecode, Decode, Encode, +}; +use daachorse::DoubleArrayAhoCorasick; +use hashbrown::HashMap; + +use crate::errors::{Result, VaporettoError}; +use crate::ngram_model::{NgramModel, TagNgramModel}; +use 
crate::predictor::{PositionalWeight, PositionalWeightWithTag, WeightVector}; +use crate::sentence::Sentence; +use crate::type_scorer::TypeWeightMerger; +use crate::utils::SplitMix64Builder; + +pub struct TypeScorerBoundaryTag { + pma: DoubleArrayAhoCorasick, + weights: Vec<Option<PositionalWeight<WeightVector>>>, + tag_weight: Vec<Vec<HashMap<u32, WeightVector, SplitMix64Builder>>>, +} + +impl<'de> BorrowDecode<'de> for TypeScorerBoundaryTag { + /// WARNING: This function is inherently unsafe. Do not publish this function outside this + /// crate. + fn borrow_decode<D: BorrowDecoder<'de>>(decoder: &mut D) -> Result<Self, DecodeError> { + let pma_data: &[u8] = BorrowDecode::borrow_decode(decoder)?; + let (pma, _) = + unsafe { DoubleArrayAhoCorasick::deserialize_from_slice_unchecked(pma_data) }; + let tag_weight: Vec<Vec<Vec<(u32, WeightVector)>>> = Decode::decode(decoder)?; + let tag_weight = tag_weight + .into_iter() + .map(|x| x.into_iter().map(|x| x.into_iter().collect()).collect()) + .collect(); + Ok(Self { + pma, + weights: Decode::decode(decoder)?, + tag_weight, + }) + } +} + +impl Encode for TypeScorerBoundaryTag { + fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> { + let pma_data = self.pma.serialize_to_vec(); + Encode::encode(&pma_data, encoder)?; + Encode::encode(&self.weights, encoder)?; + let tag_weight: Vec<Vec<Vec<(&u32, &WeightVector)>>> = self + .tag_weight + .iter() + .map(|x| x.iter().map(|x| x.iter().collect()).collect()) + .collect(); + Encode::encode(&tag_weight, encoder)?; + Ok(()) + } +} + +impl TypeScorerBoundaryTag { + pub fn new( + ngram_model: NgramModel<Vec<u8>>, + window_size: u8, + tag_ngram_model: Vec<TagNgramModel<Vec<u8>>>, + ) -> Result<Self> { + let mut merger = TypeWeightMerger::default(); + for d in ngram_model.0 { + let weight = PositionalWeightWithTag::with_boundary(-i16::from(window_size), d.weights); + merger.add(d.ngram, weight); + } + let mut tag_weight = + vec![ + vec![HashMap::with_hasher(SplitMix64Builder); usize::from(window_size) + 1]; + tag_ngram_model.len() + ]; + for (i, tag_model) in tag_ngram_model.into_iter().enumerate() { + for d in tag_model.0 { + for w in d.weights { + let weight = PositionalWeightWithTag::with_tag(i, w.rel_position, w.weights); + merger.add(d.ngram.as_slice(), weight); + } + } + } + let mut ngrams = vec![]; + let mut weights = vec![]; + for (i, (ngram, weight)) in merger.merge().into_iter().enumerate() { + ngrams.push(ngram); + weights.push(weight.weight.map(|w| w.into())); + for ((token_id, rel_position), weight) in weight.tag_info { + tag_weight[token_id][usize::from(rel_position)] + .insert(u32::try_from(i).unwrap(), weight.into()); + } + } + let pma = DoubleArrayAhoCorasick::new(ngrams) + .map_err(|_| VaporettoError::invalid_model("failed to build the automaton"))?; + Ok(Self { + pma, + weights, + tag_weight, + }) + } + + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_possible_wrap)] + #[inline(always)] + pub fn add_scores<'a, 'b>(&self, sentence: &mut Sentence<'a, 'b>) { + sentence.type_pma_states.clear(); + sentence.type_pma_states.resize(sentence.len(), u32::MAX); + for m in self + .pma + .find_overlapping_no_suffix_iter(&sentence.char_types) + { + debug_assert!(m.end() != 0 && m.end() <= sentence.char_types.len()); + debug_assert!(m.value() < self.weights.len()); + if let Some(weight) = unsafe { self.weights.get_unchecked(m.value()) } { + weight.add_score( + (m.end() + sentence.score_padding - 1) as isize, + &mut sentence.boundary_scores, + ); + } + debug_assert!(m.end() <= sentence.type_pma_states.len()); + unsafe { *sentence.type_pma_states.get_unchecked_mut(m.end() - 1) = m.value() as u32 }; + } + } + + /// # Safety + /// + /// `token_id` must be smaller than `scorer.tag_weight.len()`.
+ /// `pos` must be smaller than `sentence.type_pma_states.len()`. + #[inline(always)] + pub unsafe fn add_tag_scores( + &self, + token_id: u32, + pos: usize, + sentence: &Sentence, + scores: &mut [i32], + ) { + let tag_weight = self + .tag_weight + .get_unchecked(usize::try_from(token_id).unwrap()); + for (state_id, tag_weights) in sentence + .type_pma_states + .get_unchecked(pos..) + .iter() + .zip(tag_weight) + { + if let Some(weight) = tag_weights.get(state_id) { + weight.add_scores(scores); + } + } + } +} diff --git a/vaporetto/src/utils.rs b/vaporetto/src/utils.rs index 3a02b1fb..defe8fe3 100644 --- a/vaporetto/src/utils.rs +++ b/vaporetto/src/utils.rs @@ -1,129 +1,162 @@ -use core::cell::RefCell; +use core::hash::{BuildHasher, Hash, Hasher}; +use core::num::Wrapping; +use core::ops::{Deref, DerefMut}; -use alloc::collections::BTreeMap; -use alloc::string::{String, ToString}; use alloc::vec::Vec; #[cfg(feature = "kytea")] use std::io::{self, Read}; -use bincode::enc::write::Writer; -use bincode::error::EncodeError; +use bincode::{ + de::Decoder, + enc::{write::Writer, Encoder}, + error::{DecodeError, EncodeError}, + Decode, Encode, +}; +use hashbrown::HashMap; + +#[cfg(feature = "fix-weight-length")] +#[inline(always)] +pub const fn trim_end_zeros(mut w: &[i32]) -> &[i32] { + while let Some((&last, rest)) = w.split_last() { + if last != 0 { + break; + } + w = rest; + } + w +} -pub trait AddWeight { - fn add_weight(&self, target: &mut [i32], offset: usize); +pub struct VecWriter(pub Vec); - #[cfg(feature = "tag-prediction")] - fn add_weight_signed(&self, target: &mut [i32], offset: isize); +impl Writer for VecWriter { + fn write(&mut self, bytes: &[u8]) -> Result<(), EncodeError> { + self.0.extend_from_slice(bytes); + Ok(()) + } } -impl AddWeight for Vec { - fn add_weight(&self, ys: &mut [i32], offset: usize) { - if let Some(ys) = ys.get_mut(offset..) { - for (w, y) in self.iter().zip(ys) { - *y += w; - } - } - } +#[derive(Debug)] +pub struct SerializableHashMap(pub HashMap); - #[cfg(feature = "tag-prediction")] - fn add_weight_signed(&self, ys: &mut [i32], offset: isize) { - if offset >= 0 { - if let Some(ys) = ys.get_mut(offset as usize..) { - for (w, y) in self.iter().zip(ys) { - *y += w; - } - } - } else if let Some(ws) = self.get(-offset as usize..) 
{ - for (w, y) in ws.iter().zip(ys.iter_mut()) { - *y += w; - } - } +impl Deref for SerializableHashMap { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.0 } } -pub trait MergableWeight { - fn from_two_weights(weight1: &Self, weight2: &Self, n_classes: usize) -> Self; +impl DerefMut for SerializableHashMap { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } -pub struct WeightMerger { - map: BTreeMap>, - n_classes: usize, +impl Decode for SerializableHashMap +where + K: Encode + Decode + Eq + Hash, + V: Encode + Decode, +{ + fn decode(decoder: &mut D) -> Result { + let raw: Vec<(K, V)> = Decode::decode(decoder)?; + Ok(Self(raw.into_iter().collect())) + } } -impl WeightMerger +impl Encode for SerializableHashMap where - W: MergableWeight, + K: Encode + Decode, + V: Encode + Decode, { - pub fn new(n_classes: usize) -> Self { - Self { - map: BTreeMap::new(), - n_classes, - } + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + let raw: Vec<(&K, &V)> = self.0.iter().collect(); + Encode::encode(&raw, encoder)?; + Ok(()) } +} - pub fn add(&mut self, ngram: &str, weight: W) { - if let Some(data) = self.map.get_mut(ngram) { - let (prev_weight, _) = &mut *data.borrow_mut(); - *prev_weight = W::from_two_weights(&weight, prev_weight, self.n_classes); - } else { - self.map - .insert(ngram.to_string(), RefCell::new((weight, false))); - } +// Copied from https://prng.di.unimi.it/splitmix64.c +pub struct SplitMix64 { + x: Wrapping, +} + +impl Hasher for SplitMix64 { + #[inline(always)] + fn finish(&self) -> u64 { + let mut z = self.x; + z = (z ^ (z >> 30)) * Wrapping(0xbf58476d1ce4e5b9); + z = (z ^ (z >> 27)) * Wrapping(0x94d049bb133111eb); + (z ^ (z >> 31)).0 } - pub fn merge(self) -> Vec<(String, W)> { - let mut stack = vec![]; - for (ngram, data) in &self.map { - if data.borrow().1 { - continue; - } - stack.push(data); - for (j, _) in ngram.char_indices().skip(1) { - if let Some(data) = self.map.get(&ngram[j..]) { - stack.push(data); - if data.borrow().1 { - break; - } - } - } - let mut data_from = stack.pop().unwrap(); - data_from.borrow_mut().1 = true; - while let Some(data_to) = stack.pop() { - let new_data = ( - W::from_two_weights(&data_from.borrow().0, &data_to.borrow().0, self.n_classes), - true, - ); - *data_to.borrow_mut() = new_data; - data_from = data_to; - } + #[inline(always)] + fn write(&mut self, bytes: &[u8]) { + for &i in bytes { + self.x ^= u64::from(i); + self.x += 0x9e3779b97f4a7c15; } - self.map - .into_iter() - .map(|(ngram, weight)| (ngram, weight.into_inner().0)) - .collect() } -} -pub struct VecWriter(pub Vec); + #[inline(always)] + fn write_u8(&mut self, i: u8) { + self.x ^= u64::from(i); + self.x += 0x9e3779b97f4a7c15; + } -impl Writer for VecWriter { - fn write(&mut self, bytes: &[u8]) -> Result<(), EncodeError> { - self.0.extend_from_slice(bytes); - Ok(()) + #[inline(always)] + fn write_u16(&mut self, i: u16) { + self.x ^= u64::from(i); + self.x += 0x9e3779b97f4a7c15; + } + + #[inline(always)] + fn write_u32(&mut self, i: u32) { + self.x ^= u64::from(i); + self.x += 0x9e3779b97f4a7c15; + } + + #[inline(always)] + fn write_u64(&mut self, i: u64) { + self.x ^= i; + self.x += 0x9e3779b97f4a7c15; + } + + #[inline(always)] + fn write_i8(&mut self, i: i8) { + self.x ^= i as u64; + self.x += 0x9e3779b97f4a7c15; + } + + #[inline(always)] + fn write_i16(&mut self, i: i16) { + self.x ^= i as u64; + self.x += 0x9e3779b97f4a7c15; + } + + #[inline(always)] + fn write_i32(&mut self, i: i32) { + self.x ^= i as u64; + self.x += 
0x9e3779b97f4a7c15; + } + + #[inline(always)] + fn write_i64(&mut self, i: i64) { + self.x ^= i as u64; + self.x += 0x9e3779b97f4a7c15; } } -#[cfg(feature = "tag-prediction")] -pub fn xor_or_zip_with(lhs: &Option, rhs: &Option, f: F) -> Option -where - T: Clone, - F: FnOnce(&T, &T) -> T, -{ - lhs.as_ref().map_or_else( - || rhs.clone(), - |x1| Some(rhs.as_ref().map_or_else(|| x1.clone(), |x2| f(x1, x2))), - ) +#[derive(Clone, Copy, Default)] +pub struct SplitMix64Builder; + +impl BuildHasher for SplitMix64Builder { + type Hasher = SplitMix64; + + #[inline(always)] + fn build_hasher(&self) -> Self::Hasher { + SplitMix64 { x: Wrapping(0) } + } } #[cfg(feature = "kytea")] diff --git a/vaporetto_rules/Cargo.toml b/vaporetto_rules/Cargo.toml index b2520189..f59085cf 100644 --- a/vaporetto_rules/Cargo.toml +++ b/vaporetto_rules/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vaporetto_rules" -version = "0.4.0" +version = "0.5.0" edition = "2021" authors = ["Koichi Akabe "] description = "Rule-base filters for Vaporetto" @@ -10,7 +10,11 @@ repository = "https://github.com/daac-tools/vaporetto" readme = "README.md" keywords = ["japanese", "analyzer", "tokenizer", "morphological"] categories = ["text-processing", "no-std"] +resolver = "2" [dependencies] unicode-segmentation = "1.9.0" # MIT or Apache-2.0 -vaporetto = { path = "../vaporetto", version = "0.4.0", default-features = false } # MIT or Apache-2.0 +vaporetto = { path = "../vaporetto", version = "0.5.0", default-features = false, features = ["alloc"] } # MIT or Apache-2.0 + +[dev-dependencies] +vaporetto = { path = "../vaporetto", version = "0.5.0" } # MIT or Apache-2.0 diff --git a/vaporetto_rules/README.md b/vaporetto_rules/README.md index f8edbeac..3c80810e 100644 --- a/vaporetto_rules/README.md +++ b/vaporetto_rules/README.md @@ -19,31 +19,31 @@ use vaporetto_rules::{ let mut f = BufReader::new(File::open("model.bin").unwrap()); let model = Model::read(&mut f).unwrap(); -let mut predictor = Predictor::new(model).unwrap(); +let mut predictor = Predictor::new(model, false).unwrap(); -let pre_filters: Vec> = vec![ - Box::new(KyteaFullwidthFilter::new()), +let pre_filters: Vec>> = vec![ + Box::new(KyteaFullwidthFilter), ]; let post_filters: Vec> = vec![ - Box::new(ConcatGraphemeClustersFilter::new()), + Box::new(ConcatGraphemeClustersFilter), Box::new(KyteaWsConstFilter::new(CharacterType::Digit)), ]; let input = "Vaporettoは仲良し家族👨‍👨‍👧‍👦を離れ離れにさせません。" .to_string(); -let input = Rc::new(input); -let preproc_input = pre_filters.iter().fold(input, |s, filter| Rc::new(filter.filter(&s))); -let preproc_input = Rc::try_unwrap(preproc_input).unwrap(); +let preproc_input = pre_filters.iter().fold(input, |s, filter| filter.filter(s)); -let sentence = Sentence::from_raw(preproc_input).unwrap(); -let sentence = predictor.predict(sentence); +let mut sentence = Sentence::from_raw(preproc_input).unwrap(); +predictor.predict(&mut sentence); -let postproc_result = post_filters.iter().fold(sentence, |s, filter| filter.filter(s)); +post_filters.iter().for_each(|filter| filter.filter(&mut sentence)); +let mut buf = String::new(); +sentence.write_tokenized_text(&mut buf); assert_eq!( "Vaporetto は 仲良 し 家族 👨‍👨‍👧‍👦 を 離れ離れ に さ せ ま せ ん 。", - postproc_result.to_tokenized_string().unwrap(), + buf, ); ``` diff --git a/vaporetto_rules/src/lib.rs b/vaporetto_rules/src/lib.rs index 5167eca8..967ac53b 100644 --- a/vaporetto_rules/src/lib.rs +++ b/vaporetto_rules/src/lib.rs @@ -16,11 +16,11 @@ //! string_filters::KyteaFullwidthFilter, //! }; //! -//! 
let mut f = BufReader::new(File::open("model.bin").unwrap()); -//! let model = Model::read(&mut f).unwrap(); +//! let f = BufReader::new(File::open("model.bin").unwrap()); +//! let model = Model::read(f).unwrap(); //! let mut predictor = Predictor::new(model, false).unwrap(); //! -//! let pre_filters: Vec> = vec![ +//! let pre_filters: Vec>> = vec![ //! Box::new(KyteaFullwidthFilter), //! ]; //! let post_filters: Vec> = vec![ @@ -31,21 +31,20 @@ //! let input = "Vaporettoは仲良し家族👨‍👨‍👧‍👦を離れ離れにさせません。" //! .to_string(); //! -//! let input = Rc::new(input); -//! let preproc_input = pre_filters.iter().fold(input, |s, filter| Rc::new(filter.filter(&s))); -//! let preproc_input = Rc::try_unwrap(preproc_input).unwrap(); +//! let preproc_input = pre_filters.iter().fold(input, |s, filter| filter.filter(s)); //! -//! let sentence = Sentence::from_raw(preproc_input).unwrap(); -//! let sentence = predictor.predict(sentence); +//! let mut sentence = Sentence::from_raw(preproc_input).unwrap(); +//! predictor.predict(&mut sentence); //! -//! let postproc_result = post_filters.iter().fold(sentence, |s, filter| filter.filter(s)); +//! post_filters.iter().for_each(|filter| filter.filter(&mut sentence)); //! +//! let mut buf = String::new(); +//! sentence.write_tokenized_text(&mut buf); //! assert_eq!( //! "Vaporetto は 仲良 し 家族 👨‍👨‍👧‍👦 を 離れ離れ に さ せ ま せ ん 。", -//! postproc_result.to_tokenized_string().unwrap(), +//! buf, //! ); //! ``` -//! #![no_std] @@ -60,26 +59,13 @@ use vaporetto::Sentence; pub trait SentenceFilter: Send + Sync { /// Filter a specified sentence using rules. - /// - /// # Arguments: - /// - /// * `sentence` - Input sentence. - /// - /// # Returns - /// - /// A processed sentence. - fn filter(&self, sentence: Sentence) -> Sentence; + fn filter(&self, sentence: &mut Sentence); } -pub trait StringFilter: Send + Sync { +pub trait StringFilter: Send + Sync +where + S: AsRef, +{ /// Filter a specified string using rules. - /// - /// # Arguments: - /// - /// * `string` - Input string. - /// - /// # Returns - /// - /// A processed string. - fn filter(&self, string: &str) -> String; + fn filter(&self, string: S) -> String; } diff --git a/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs b/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs index 39d149b2..d83834f5 100644 --- a/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs +++ b/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs @@ -1,5 +1,5 @@ use unicode_segmentation::UnicodeSegmentation; -use vaporetto::{BoundaryType, Sentence}; +use vaporetto::{CharacterBoundary, Sentence}; use crate::SentenceFilter; @@ -8,15 +8,29 @@ use crate::SentenceFilter; pub struct ConcatGraphemeClustersFilter; impl SentenceFilter for ConcatGraphemeClustersFilter { - fn filter(&self, mut sentence: Sentence) -> Sentence { - let mut tmp = sentence.boundaries().to_vec(); - for (i, c) in sentence.to_raw_string().grapheme_indices(true) { - let start = sentence.get_char_pos(i).unwrap(); - let end = sentence.get_char_pos(i + c.len()).unwrap() - 1; - tmp[start..end].fill(BoundaryType::NotWordBoundary); + fn filter(&self, sentence: &mut Sentence) { + let mut start = 0; + let mut offset = 0; + unsafe { + debug_assert!(sentence.as_raw_text().is_char_boundary(offset)); + while let Some((len, n_chars)) = sentence + .as_raw_text() + .get_unchecked(offset..) 
+ .graphemes(true) + .next() + .map(|x| (x.len(), x.chars().count())) + { + offset += len; + let end = start + n_chars; + debug_assert!(start <= sentence.boundaries().len()); + debug_assert!(end <= sentence.boundaries().len() + 1); + sentence + .boundaries_mut() + .get_unchecked_mut(start..end - 1) + .fill(CharacterBoundary::NotWordBoundary); + start = end; + } } - sentence.boundaries_mut().copy_from_slice(&tmp); - sentence } } @@ -24,42 +38,46 @@ impl SentenceFilter for ConcatGraphemeClustersFilter { mod tests { use super::*; + use alloc::string::String; + #[test] fn test_concat_grapheme_clusters_no_boundary() { - let s = Sentence::from_tokenized("\u{200d}").unwrap(); + let mut s = Sentence::from_tokenized("\u{200d}").unwrap(); let filter = ConcatGraphemeClustersFilter; - let s = filter.filter(s); - assert_eq!("\u{200d}", s.to_tokenized_string().unwrap()); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("\u{200d}", buf); } #[test] fn test_concat_grapheme_clusters_zwj() { - let s = + let mut s = Sentence::from_tokenized("\u{1f468} \u{200d} \u{1f469} \u{200d} \u{1f466}").unwrap(); let filter = ConcatGraphemeClustersFilter; - let s = filter.filter(s); - assert_eq!( - "\u{1f468}\u{200d}\u{1f469}\u{200d}\u{1f466}", - s.to_tokenized_string().unwrap() - ); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("\u{1f468}\u{200d}\u{1f469}\u{200d}\u{1f466}", buf); } #[test] fn test_concat_grapheme_clusters_color() { - let s = Sentence::from_tokenized("\u{1f44f} \u{1f3fd}").unwrap(); + let mut s = Sentence::from_tokenized("\u{1f44f} \u{1f3fd}").unwrap(); let filter = ConcatGraphemeClustersFilter; - let s = filter.filter(s); - assert_eq!("\u{1f44f}\u{1f3fd}", s.to_tokenized_string().unwrap()); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("\u{1f44f}\u{1f3fd}", buf); } #[test] fn test_concat_grapheme_clusters_combined() { - let s = Sentence::from_tokenized("これ は 手 \u{1f44f} \u{1f3fd} で す").unwrap(); + let mut s = Sentence::from_tokenized("これ は 手 \u{1f44f} \u{1f3fd} で す").unwrap(); let filter = ConcatGraphemeClustersFilter; - let s = filter.filter(s); - assert_eq!( - "これ は 手 \u{1f44f}\u{1f3fd} で す", - s.to_tokenized_string().unwrap() - ); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("これ は 手 \u{1f44f}\u{1f3fd} で す", buf); } } diff --git a/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs b/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs index 07d69964..fa72bd91 100644 --- a/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs +++ b/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs @@ -1,4 +1,4 @@ -use vaporetto::{BoundaryType, CharacterType, Sentence}; +use vaporetto::{CharacterBoundary, CharacterType, Sentence}; use crate::SentenceFilter; @@ -24,15 +24,22 @@ impl KyteaWsConstFilter { } impl SentenceFilter for KyteaWsConstFilter { - fn filter(&self, mut sentence: Sentence) -> Sentence { + fn filter(&self, sentence: &mut Sentence) { let t_flag = self.char_type as u8; - let (_, char_types, boundaries) = sentence.chars_and_boundaries_mut(); - for ((t1, t2), b) in char_types.iter().zip(&char_types[1..]).zip(boundaries) { - if *t1 == t_flag && *t2 == t_flag { - *b = BoundaryType::NotWordBoundary; + let len = sentence.char_types().len() - 1; + for i in 0..len { + debug_assert!(i < sentence.char_types().len()); + debug_assert!(i + 1 < sentence.char_types().len()); + unsafe 
{ + if *sentence.char_types().get_unchecked(i) == t_flag + && *sentence.char_types().get_unchecked(i + 1) == t_flag + { + debug_assert!(i < sentence.boundaries().len()); + *sentence.boundaries_mut().get_unchecked_mut(i) = + CharacterBoundary::NotWordBoundary; + } } } - sentence } } @@ -40,27 +47,35 @@ impl SentenceFilter for KyteaWsConstFilter { mod tests { use super::*; + use alloc::string::String; + #[test] fn test_concat_cons_char_types_no_boundary() { - let s = Sentence::from_tokenized("5").unwrap(); + let mut s = Sentence::from_tokenized("5").unwrap(); let filter = KyteaWsConstFilter::new(CharacterType::Digit); - let s = filter.filter(s); - assert_eq!("5", s.to_tokenized_string().unwrap()); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("5", buf); } #[test] fn test_concat_cons_char_types() { - let s = Sentence::from_tokenized("5 00 0").unwrap(); + let mut s = Sentence::from_tokenized("5 00 0").unwrap(); let filter = KyteaWsConstFilter::new(CharacterType::Digit); - let s = filter.filter(s); - assert_eq!("5000", s.to_tokenized_string().unwrap()); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("5000", buf); } #[test] fn test_concat_cons_char_types_combined() { - let s = Sentence::from_tokenized("20 21 年 8 月 2 4 日").unwrap(); + let mut s = Sentence::from_tokenized("20 21 年 8 月 2 4 日").unwrap(); let filter = KyteaWsConstFilter::new(CharacterType::Digit); - let s = filter.filter(s); - assert_eq!("2021 年 8 月 24 日", s.to_tokenized_string().unwrap()); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("2021 年 8 月 24 日", buf); } } diff --git a/vaporetto_rules/src/sentence_filters/split_linebreaks.rs b/vaporetto_rules/src/sentence_filters/split_linebreaks.rs index 71156946..73ff34a2 100644 --- a/vaporetto_rules/src/sentence_filters/split_linebreaks.rs +++ b/vaporetto_rules/src/sentence_filters/split_linebreaks.rs @@ -1,4 +1,4 @@ -use vaporetto::{BoundaryType, Sentence}; +use vaporetto::{CharacterBoundary, Sentence}; use crate::SentenceFilter; @@ -7,17 +7,32 @@ use crate::SentenceFilter; pub struct SplitLinebreaksFilter; impl SentenceFilter for SplitLinebreaksFilter { - fn filter(&self, mut sentence: Sentence) -> Sentence { - let (chars, _, boundaries) = sentence.chars_and_boundaries_mut(); - for ((c1, c2), b) in chars.iter().zip(&chars[1..]).zip(boundaries) { - match (*c1, *c2) { - ('\r' | '\n', _) | (_, '\r' | '\n') => { - *b = BoundaryType::WordBoundary; + fn filter(&self, sentence: &mut Sentence) { + unsafe { + debug_assert!(!sentence.as_raw_text().is_empty()); + let mut prev_c = sentence.as_raw_text().chars().next().unwrap_unchecked(); + let mut offset = prev_c.len_utf8(); + let mut i = 0; + debug_assert!(sentence.as_raw_text().is_char_boundary(offset)); + while let Some(c) = sentence + .as_raw_text() + .get_unchecked(offset..) 
+ .chars() + .next() + { + offset += c.len_utf8(); + match (prev_c, c) { + ('\r' | '\n', _) | (_, '\r' | '\n') => { + debug_assert!(i < sentence.boundaries().len()); + *sentence.boundaries_mut().get_unchecked_mut(i) = + CharacterBoundary::WordBoundary; + } + _ => {} } - _ => {} + prev_c = c; + i += 1; } } - sentence } } @@ -25,27 +40,35 @@ impl SentenceFilter for SplitLinebreaksFilter { mod tests { use super::*; + use alloc::string::String; + #[test] fn test_split_lf() { - let s = Sentence::from_tokenized("前の行\n次の行").unwrap(); + let mut s = Sentence::from_tokenized("前の行\n次の行").unwrap(); let filter = SplitLinebreaksFilter; - let s = filter.filter(s); - assert_eq!("前の行 \n 次の行", s.to_tokenized_string().unwrap()); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("前の行 \n 次の行", buf); } #[test] fn test_split_cr() { - let s = Sentence::from_tokenized("前の行\r次の行").unwrap(); + let mut s = Sentence::from_tokenized("前の行\r次の行").unwrap(); let filter = SplitLinebreaksFilter; - let s = filter.filter(s); - assert_eq!("前の行 \r 次の行", s.to_tokenized_string().unwrap()); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("前の行 \r 次の行", buf); } #[test] fn test_split_crlf() { - let s = Sentence::from_tokenized("前の行\r\n次の行").unwrap(); + let mut s = Sentence::from_tokenized("前の行\r\n次の行").unwrap(); let filter = SplitLinebreaksFilter; - let s = filter.filter(s); - assert_eq!("前の行 \r \n 次の行", s.to_tokenized_string().unwrap()); + filter.filter(&mut s); + let mut buf = String::new(); + s.write_tokenized_text(&mut buf); + assert_eq!("前の行 \r \n 次の行", buf); } } diff --git a/vaporetto_rules/src/string_filters.rs b/vaporetto_rules/src/string_filters.rs index 5e27109a..e8705184 100644 --- a/vaporetto_rules/src/string_filters.rs +++ b/vaporetto_rules/src/string_filters.rs @@ -1,4 +1,4 @@ -//! Filters for [`String`]. +//! Filters for [`String`](alloc::string::String). 
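+//!
+//! A minimal usage sketch, assuming the `KyteaFullwidthFilter` re-exported from
+//! this module (the input/output strings are illustrative):
+//!
+//! ```
+//! use vaporetto_rules::{string_filters::KyteaFullwidthFilter, StringFilter};
+//!
+//! let filter = KyteaFullwidthFilter;
+//! // Halfwidth ASCII letters are mapped to their fullwidth forms.
+//! assert_eq!("abc", filter.filter("abc"));
+//! ```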
mod kytea_fullwidth; diff --git a/vaporetto_rules/src/string_filters/kytea_fullwidth.rs b/vaporetto_rules/src/string_filters/kytea_fullwidth.rs index 61ff166e..07311c88 100644 --- a/vaporetto_rules/src/string_filters/kytea_fullwidth.rs +++ b/vaporetto_rules/src/string_filters/kytea_fullwidth.rs @@ -1,5 +1,4 @@ use alloc::string::String; -use alloc::vec::Vec; use crate::StringFilter; @@ -7,11 +6,14 @@ use crate::StringFilter; #[derive(Clone, Default)] pub struct KyteaFullwidthFilter; -impl StringFilter for KyteaFullwidthFilter { - fn filter(&self, string: &str) -> String { - let mut chars: Vec<_> = string.chars().collect(); - for c in &mut chars { - *c = match *c { +impl StringFilter for KyteaFullwidthFilter +where + S: AsRef, +{ + fn filter(&self, string: S) -> String { + let mut result = String::new(); + for c in string.as_ref().chars() { + result.push(match c { 'a' => 'a', 'b' => 'b', 'c' => 'c', @@ -109,8 +111,8 @@ impl StringFilter for KyteaFullwidthFilter { '@' => '@', '=' => '=', c => c, - }; + }); } - chars.into_iter().collect() + result } } diff --git a/vaporetto_tantivy/Cargo.toml b/vaporetto_tantivy/Cargo.toml index d70d0fa6..0597638e 100644 --- a/vaporetto_tantivy/Cargo.toml +++ b/vaporetto_tantivy/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vaporetto_tantivy" -version = "0.4.0" +version = "0.5.0" edition = "2021" authors = ["Koichi Akabe "] description = "Vaporetto Tokenizer for Tantivy" @@ -12,9 +12,9 @@ keywords = ["japanese", "tokenizer", "tantivy"] categories = ["text-processing"] [dependencies] -vaporetto = { path = "../vaporetto", version = "0.4.0" } # MIT or Apache-2.0 -vaporetto_rules = { path = "../vaporetto_rules", version = "0.4.0" } # MIT or Apache-2.0 -tantivy = "0.17" # MIT +vaporetto = { path = "../vaporetto", version = "0.5.0" } # MIT or Apache-2.0 +vaporetto_rules = { path = "../vaporetto_rules", version = "0.5.0" } # MIT or Apache-2.0 +tantivy = "0.18" # MIT [dev-dependencies] ruzstd = "0.2.4" # MIT diff --git a/vaporetto_tantivy/src/lib.rs b/vaporetto_tantivy/src/lib.rs index ec2f375e..19b4b926 100644 --- a/vaporetto_tantivy/src/lib.rs +++ b/vaporetto_tantivy/src/lib.rs @@ -51,7 +51,7 @@ use std::sync::Arc; use tantivy::tokenizer::{BoxTokenStream, Token, TokenStream, Tokenizer}; -use vaporetto::{BoundaryType, CharacterType, Model, Predictor, Sentence}; +use vaporetto::{CharacterBoundary, CharacterType, Model, Predictor, Sentence}; use vaporetto_rules::{ sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter, SplitLinebreaksFilter}, string_filters::KyteaFullwidthFilter, @@ -125,22 +125,21 @@ impl Tokenizer for VaporettoTokenizer { // pre filter let prefiltered_text = self.prefilter.filter(text); - let prefiltered_sentence = Sentence::from_raw(prefiltered_text).unwrap(); + let mut s = Sentence::from_raw(prefiltered_text).unwrap(); // tokenize - let tokenized_sentence = self.predictor.predict(prefiltered_sentence); + self.predictor.predict(&mut s); // post filter - let postfiltered_sentence = self - .postfilters + self.postfilters .iter() - .fold(tokenized_sentence, |s, filter| filter.filter(s)); + .for_each(|filter| filter.filter(&mut s)); let mut char_indices = text.char_indices(); char_indices.next(); - let mut boundary_pos = Vec::with_capacity(postfiltered_sentence.chars().len()); - for ((i, _), &b) in char_indices.zip(postfiltered_sentence.boundaries()) { - if b == BoundaryType::WordBoundary { + let mut boundary_pos = Vec::with_capacity(s.boundaries().len() + 1); + for ((i, _), &b) in char_indices.zip(s.boundaries()) { + if b == 
CharacterBoundary::WordBoundary { boundary_pos.push(i); } } diff --git a/vaporetto_tantivy/test_model/model.zst b/vaporetto_tantivy/test_model/model.zst index 538467ae9830a2a5714e11919c6ae5dc6e889017..11a7809d0e953db151b8bb150bc1126886a74244 100644 GIT binary patch literal 332 zcmV-S0ki%nwJ-euXdMaw)=)n=Ft5@WpqLycJNmycNFhGhX3{&=>a(h=3~Kmqs6-=l z4YK;96vG3j%zTU{r9j!c4K@h2?#-c}0yF?E051T03SEw>YTZ<0_rP}9I?+}x*_?B+ ze=hrT<&wpLfDkAF6#7P@jEM>n1^57xNTLyFApsGs6Rm@JW3n$-E*VYDrU1B4nws76 zEYTW&1(scwXk~?U{s1bwhq03t);K2?4$G||?4RX)568}xXzlKxRSc>wjJFsD&(29H32wWT zQe%L!e>D5gOeX%DA1F{rK; exY@yoX1%orHEt_(E4Y!oP_)im z+!xC2^+Hh`2