Reimplement Vaporetto with support for multiple tags (#35)
* Reimplement Vaporetto with support for multiple tags

* Update examples

* fmt all

* clippy and fix bugs

* fmt

* fix bugs

* Add charwise-pma feature

* fmt

* clippy

* fix

* fix features

* docs

* Implement Eq for CharacterBoundary

* fix features

* fix CI

* Update README

* update README-ja

* fix a bug

* fix README

* Update vaporetto/src/dict_model.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* Update vaporetto/src/utils.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* Update vaporetto/src/dict_model.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* Update docs of Sentence

* update docs

* Add a description of Sentence::default()

* Update docs

* fix docs

* fmt

* fix minor

* fix lint

* fix doc

* Update vaporetto/src/sentence.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* Update vaporetto/src/sentence.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* Update vaporetto/src/sentence.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* Update vaporetto/src/sentence.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* Update vaporetto/src/sentence.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* Update vaporetto/src/sentence.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* fix docs

* Update vaporetto/src/sentence.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* fix docs

* update docs

* str_to_char_pos

* fix

* Update vaporetto/src/sentence.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* Update sentence.rs

* debug_assert() for daachorse

* Update vaporetto/src/char_scorer.rs

Co-authored-by: Shunsuke Kanda <[email protected]>

* fix

* fix arg names

* fix

* fix

* check window_size

* fix

* Add Nop

* Revert "Add Nop"

This reverts commit 3b8d9d9.

* Revert "fix"

This reverts commit ccf3657.

* Revert "check window_size"

This reverts commit d17062e.

* fix serialization

* fix

* wrap tag_predictor by Option

* add panic test

* add comments for unsafe blocks

* fix

* add trim_end_zeros()

* refactoring

* add TagWeight

* Add description of TagModel

* Update model.rs

* Use HashMap in model

* fix

* tantivy 0.18

* fix

* Add wrapper of HashMap for bincode

* fix

* fix

* fix

* fix

* fix

* add cfg annotation

* fix

* fix

* Use Vec in Model

* Apply suggestions from code review

Co-authored-by: Shunsuke Kanda <[email protected]>

* Add debug_assert!()

* Remove redundant clone()

* fix

* fix

* Apply suggestions from code review

Co-authored-by: Shunsuke Kanda <[email protected]>

Co-authored-by: Shunsuke Kanda <[email protected]>
vbkaisetsu and kampersanda authored Jun 6, 2022
1 parent 77c76dc commit 58ef5f3
Showing 45 changed files with 5,477 additions and 4,648 deletions.
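The headline change is multi-tag support: each character slot can now carry several tags (e.g. a part of speech and a reading), stored in one flat array of length `slots × n_tags` and sliced per slot, as in the `s.tags()[i * s.n_tags()..(i + 1) * s.n_tags()]` expressions in the `evaluate` diff. A minimal self-contained sketch of that layout — the types and function name here are illustrative, not Vaporetto's own API:

```rust
// Illustrative flat tag layout (not Vaporetto's actual types): the tags of
// character slot `i` occupy tags[i * n_tags .. (i + 1) * n_tags], mirroring
// the `s.tags()[i * s.n_tags()..(i + 1) * s.n_tags()]` slicing in the diff.
fn tags_of<'a>(tags: &'a [Option<&'a str>], n_tags: usize, i: usize) -> &'a [Option<&'a str>] {
    &tags[i * n_tags..(i + 1) * n_tags]
}

fn main() {
    // Two tag kinds per slot: part of speech and reading.
    let n_tags = 2;
    let tags = [
        Some("名詞"), Some("カセイ"),  // slot 0: 火星
        Some("接尾辞"), Some("ジン"),  // slot 1: 人
    ];
    assert_eq!(tags_of(&tags, n_tags, 1), &[Some("接尾辞"), Some("ジン")]);
}
```

A flat array keeps all tags in one contiguous allocation, so a sentence with `k` tag kinds needs no per-token `Vec`.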
12 changes: 6 additions & 6 deletions .github/workflows/rust.yml
@@ -34,7 +34,7 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: clippy
args: -- -D warnings -W clippy::nursery -W clippy::cast_lossless -W clippy::cast_possible_truncation -W clippy::cast_possible_wrap
args: -- -D warnings -W clippy::nursery -W clippy::cast_lossless -W clippy::cast_possible_truncation -W clippy::cast_possible_wrap -A clippy::empty_line_after_outer_attr

- name: Run cargo test (workspace)
uses: actions-rs/cargo@v1
@@ -46,7 +46,7 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: test
args: --release -p vaporetto --no-default-features
args: --release -p vaporetto --no-default-features --features alloc

- name: Run cargo test (vaporetto / features kytea)
uses: actions-rs/cargo@v1
@@ -82,7 +82,7 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: test
args: --release -p vaporetto --no-default-features --features charwise-daachorse
args: --release -p vaporetto --no-default-features --features charwise-pma

- name: Run cargo test (vaporetto / features std)
uses: actions-rs/cargo@v1
@@ -117,7 +117,7 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: clippy
args: -- -D warnings -W clippy::nursery -W clippy::cast_lossless -W clippy::cast_possible_truncation -W clippy::cast_possible_wrap
args: -- -D warnings -W clippy::nursery -W clippy::cast_lossless -W clippy::cast_possible_truncation -W clippy::cast_possible_wrap -A clippy::empty_line_after_outer_attr

- name: Run cargo test (workspace)
uses: actions-rs/cargo@v1
@@ -129,7 +129,7 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: test
args: --release -p vaporetto --no-default-features
args: --release -p vaporetto --no-default-features --features alloc

- name: Run cargo test (vaporetto / all-features)
uses: actions-rs/cargo@v1
@@ -171,7 +171,7 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: test
args: --release -p vaporetto --no-default-features --features charwise-daachorse
args: --release -p vaporetto --no-default-features --features charwise-pma

- name: Run cargo test (vaporetto / features std)
uses: actions-rs/cargo@v1
10 changes: 5 additions & 5 deletions README-ja.md
@@ -187,23 +187,23 @@ Vaporetto は2種類のコーパス、すなわちフルアノテーションコ

### 品詞推定

Vaporettoは実験的に品詞推定に対応しています
Vaporettoは実験的にタグ推定(品詞推定や読み推定)に対応しています

品詞を学習するには、以下のように、データセットの各トークンに続けてスラッシュと品詞を追加します
タグを学習するには、以下のように、データセットの各トークンに続けてスラッシュとタグを追加します

* フルアノテーションコーパスの場合
```
この/連体詞 人/名詞 は/助詞 火星/名詞 人/接尾辞 です/助動詞
この/連体詞/コノ 人/名詞/ヒト は/助詞/ワ 火星/名詞/カセイ 人/接尾辞/ジン です/助動詞/デス
```

* 部分アノテーションコーパスの場合
```
ヴ-ェ-ネ-ツ-ィ-ア/名詞|は/助詞|イ-タ-リ-ア/名詞|に/助詞|あ-り ま-す
```

データセットに品詞が含まれる場合`train` コマンドは自動的にそれらを学習します。
データセットにタグが含まれる場合`train` コマンドは自動的にそれらを学習します。

推定時は、デフォルトでは品詞は推定されないため、必要に応じで `predict` コマンドに `--predict-tags` 引数を指定してください。
推定時は、デフォルトではタグは推定されないため、必要に応じて `predict` コマンドに `--predict-tags` 引数を指定してください。

## 各種トークナイザの速度比較

8 changes: 4 additions & 4 deletions README.md
@@ -186,15 +186,15 @@ Now `外国人参政権` is split into correct tokens.
9:交代 -7658
```

### POS tagging
### Tagging

Vaporetto experimentally supports POS tagging.
Vaporetto experimentally supports tagging (e.g., part-of-speech and pronunciation tags).

To train tags, add a slash and tag name following each token in the dataset as follows:
To train tags, add slashes and tags following each token in the dataset as follows:

* For fully annotated corpora
```
この/連体詞 人/名詞 は/助詞 火星/名詞 人/接尾辞 です/助動詞
この/連体詞/コノ 人/名詞/ヒト は/助詞/ワ 火星/名詞/カセイ 人/接尾辞/ジン です/助動詞/デス
```

* For partially annotated corpora
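The extended dataset format in the README diff attaches one or more slash-separated tags to each token (`surface/tag1/tag2/…`). A hypothetical parser for one fully annotated token — an illustrative sketch, not Vaporetto's actual corpus loader, which also handles escaping and partial annotation:

```rust
// Split "surface/tag1/tag2/..." into the surface form and its tags.
// Illustrative only: Vaporetto's real loader also handles escaped
// slashes and partially annotated corpora.
fn parse_token(token: &str) -> (&str, Vec<&str>) {
    let mut parts = token.split('/');
    let surface = parts.next().unwrap_or("");
    (surface, parts.collect())
}

fn main() {
    let (surface, tags) = parse_token("人/名詞/ヒト");
    assert_eq!(surface, "人");
    assert_eq!(tags, ["名詞", "ヒト"]);
}
```

A token without slashes simply yields an empty tag list, which matches the untagged examples earlier in the README.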
45 changes: 26 additions & 19 deletions evaluate/src/main.rs
@@ -4,7 +4,7 @@ use std::path::PathBuf;
use std::str::FromStr;

use clap::Parser;
use vaporetto::{BoundaryType, CharacterType, Model, Predictor, Sentence};
use vaporetto::{CharacterBoundary, CharacterType, Model, Predictor, Sentence};
use vaporetto_rules::{
sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter},
string_filters::KyteaFullwidthFilter,
@@ -107,19 +107,25 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
if line.is_empty() {
continue;
}
let mut s = Sentence::from_tokenized(line)?;
let mut s = Sentence::from_tokenized(&line)?;
let ref_boundaries = s.boundaries().to_vec();
let ref_tags = s.tags().to_vec();
let mut ref_tags = vec![];
for i in 0..=ref_boundaries.len() {
ref_tags.push(s.tags()[i * s.n_tags()..(i + 1) * s.n_tags()].to_vec());
}
if !args.no_norm {
let new_line = fullwidth_filter.filter(s.to_raw_string());
let new_line = fullwidth_filter.filter(s.as_raw_text());
s = Sentence::from_raw(new_line)?
};
s = predictor.predict(s);
s = post_filters.iter().fold(s, |s, filter| filter.filter(s));
s = predictor.fill_tags(s);
let hyp_boundaries = s.boundaries().to_vec();
let hyp_tags = s.tags().to_vec();
results.push((ref_boundaries, ref_tags, hyp_boundaries, hyp_tags));
predictor.predict(&mut s);
post_filters.iter().for_each(|filter| filter.filter(&mut s));
s.fill_tags();
let sys_boundaries = s.boundaries().to_vec();
let mut sys_tags = vec![];
for i in 0..=sys_boundaries.len() {
sys_tags.push(s.tags()[i * s.n_tags()..(i + 1) * s.n_tags()].to_vec());
}
results.push((ref_boundaries, ref_tags, sys_boundaries, sys_tags));
}

match args.metric {
@@ -131,12 +137,12 @@
for (rs_b, _, hs_b, _) in results {
for (r, h) in rs_b.into_iter().zip(hs_b) {
if r == h {
if h == BoundaryType::WordBoundary {
if h == CharacterBoundary::WordBoundary {
n_tp += 1;
} else {
n_tn += 1;
}
} else if h == BoundaryType::WordBoundary {
} else if h == CharacterBoundary::WordBoundary {
n_fp += 1;
} else {
n_fn += 1;
@@ -159,28 +165,29 @@
let mut n_sys = 0;
let mut n_ref = 0;
let mut n_cor = 0;
for (rs_b, rs_t, hs_b, hs_t) in results {
for (refs_b, refs_t, syss_b, syss_t) in results {
let mut matched = true;
for (((r_b, r_t), h_b), h_t) in rs_b.iter().zip(&rs_t).zip(&hs_b).zip(&hs_t) {
if r_b == h_b {
if *h_b == BoundaryType::WordBoundary {
if matched && r_t == h_t {
for (((r_b, r_t), s_b), s_t) in refs_b.iter().zip(&refs_t).zip(&syss_b).zip(&syss_t)
{
if r_b == s_b {
if *s_b == CharacterBoundary::WordBoundary {
if matched && r_t == s_t {
n_cor += 1;
}
matched = true;
n_ref += 1;
n_sys += 1;
}
} else {
if *h_b == BoundaryType::WordBoundary {
if *s_b == CharacterBoundary::WordBoundary {
n_sys += 1;
} else {
n_ref += 1;
}
matched = false;
}
}
if matched && rs_t.last().unwrap() == hs_t.last().unwrap() {
if matched && refs_t.last().unwrap() == syss_t.last().unwrap() {
n_cor += 1;
}
n_sys += 1;
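The boundary metric in `evaluate/src/main.rs` counts a true positive only when both the reference and the system mark a word boundary, then derives precision, recall, and F1. A self-contained sketch of that computation — the enum mirrors `CharacterBoundary` from the diff, but the function name is illustrative:

```rust
#[derive(Clone, Copy, PartialEq, Eq)]
enum CharacterBoundary {
    WordBoundary,
    NotWordBoundary,
}

// Precision, recall, and F1 over word-boundary decisions, following the
// TP/FP/FN counting in the evaluate diff. Assumes at least one reference
// and one predicted boundary (otherwise the ratios are NaN).
fn boundary_prf(refs: &[CharacterBoundary], syss: &[CharacterBoundary]) -> (f64, f64, f64) {
    let (mut n_tp, mut n_fp, mut n_fn) = (0usize, 0usize, 0usize);
    for (&r, &h) in refs.iter().zip(syss) {
        if r == h {
            if h == CharacterBoundary::WordBoundary {
                n_tp += 1;
            }
        } else if h == CharacterBoundary::WordBoundary {
            n_fp += 1;
        } else {
            n_fn += 1;
        }
    }
    let precision = n_tp as f64 / (n_tp + n_fp) as f64;
    let recall = n_tp as f64 / (n_tp + n_fn) as f64;
    (precision, recall, 2.0 * precision * recall / (precision + recall))
}

fn main() {
    use CharacterBoundary::*;
    let refs = [WordBoundary, NotWordBoundary, WordBoundary];
    let syss = [WordBoundary, WordBoundary, NotWordBoundary];
    let (p, r, f1) = boundary_prf(&refs, &syss);
    // One TP, one FP, one FN, so precision = recall = F1 = 0.5.
    assert!((f1 - 0.5).abs() < 1e-9);
    println!("P={p} R={r} F1={f1}");
}
```

Correctly predicted non-boundaries (true negatives) are tallied in the diff only for accuracy-style reporting; they do not enter precision or recall.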
2 changes: 1 addition & 1 deletion examples/embedded_device/Cargo.toml
@@ -15,7 +15,7 @@ vaporetto_rules = { path = "../../vaporetto_rules" }
alloc-cortex-m = "0.4.0"

[build-dependencies]
vaporetto = { path = "../../vaporetto", default-features = false }
vaporetto = { path = "../../vaporetto", default-features = false, features = ["alloc"] }
ruzstd = "0.2.4" # MIT

[profile.release]
28 changes: 13 additions & 15 deletions examples/embedded_device/src/main.rs
@@ -8,7 +8,7 @@ extern crate alloc;
use core::alloc::Layout;

// alloc crate
use alloc::vec::Vec;
use alloc::string::String;

// devices
use alloc_cortex_m::CortexMHeap;
@@ -17,11 +17,8 @@ use cortex_m_rt::entry;
use cortex_m_semihosting::hprintln;

// other crates
use vaporetto::{Predictor, Sentence, CharacterType};
use vaporetto_rules::{
sentence_filters::KyteaWsConstFilter,
SentenceFilter,
};
use vaporetto::{CharacterType, Predictor, Sentence};
use vaporetto_rules::{sentence_filters::KyteaWsConstFilter, SentenceFilter};

// panic behaviour
use panic_halt as _;
@@ -36,22 +33,23 @@ fn main() -> ! {
unsafe { ALLOCATOR.init(cortex_m_rt::heap_start() as usize, HEAP_SIZE) }

let predictor_data = include_bytes!(concat!(env!("OUT_DIR"), "/predictor.bin"));
let (predictor, _) = unsafe { Predictor::deserialize_from_slice_unchecked(predictor_data) }.unwrap();
let (predictor, _) =
unsafe { Predictor::deserialize_from_slice_unchecked(predictor_data) }.unwrap();

let docs = &[
"🚤VaporettoはSTM32F303VCT6(FLASH:256KiB,RAM:40KiB)などの小さなデバイスでも動作します",
];
let docs =
&["🚤VaporettoはSTM32F303VCT6(FLASH:256KiB,RAM:40KiB)などの小さなデバイスでも動作します"];

let wsconst_d_filter = KyteaWsConstFilter::new(CharacterType::Digit);

loop {
for &text in docs {
hprintln!("\x1b[32mINPUT:\x1b[m {:?}", text).unwrap();
let s = Sentence::from_raw(text).unwrap();
let s = predictor.predict(s);
let s = wsconst_d_filter.filter(s);
let v = s.to_tokenized_vec().unwrap().iter().map(|t| t.surface).collect::<Vec<_>>();
hprintln!("\x1b[31mOUTPUT:\x1b[m {:?}", v).unwrap();
let mut s = Sentence::from_raw(text).unwrap();
predictor.predict(&mut s);
wsconst_d_filter.filter(&mut s);
let mut buf = String::new();
s.write_tokenized_text(&mut buf);
hprintln!("\x1b[31mOUTPUT:\x1b[m {}", buf).unwrap();
}
}
}
49 changes: 12 additions & 37 deletions examples/wasm/src/lib.rs
@@ -1,7 +1,7 @@
use std::io::{Cursor, Read};

use js_sys::{Array, Object};
use vaporetto::{BoundaryType, CharacterType, Model, Predictor, Sentence};
use vaporetto::{CharacterBoundary, CharacterType, Model, Predictor, Sentence};
use vaporetto_rules::{
sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter},
string_filters::KyteaFullwidthFilter,
@@ -61,22 +61,19 @@ impl Vaporetto {
return result.into();
};
let norm = self.fullwidth_filter.filter(text);
let s_norm = if let Ok(s) = Sentence::from_raw(norm) {
let mut s_norm = if let Ok(s) = Sentence::from_raw(norm) {
s
} else {
return result.into();
};
let s_norm = self.predictor.predict(s_norm);
self.predictor.predict(&mut s_norm);
s.boundaries_mut().clone_from_slice(s_norm.boundaries());
let s = self
.post_filters
self.post_filters
.iter()
.fold(s, |s, filter| filter.filter(s));
.for_each(|filter| filter.filter(&mut s));

if let Ok(tokens) = s.to_tokenized_vec() {
for token in tokens {
result.push(&JsValue::from_str(token.surface));
}
for token in s.iter_tokens() {
result.push(&JsValue::from_str(token.surface()));
}
result.into()
}
@@ -85,41 +82,19 @@ impl Vaporetto {
pub fn predict(&self, text: &str) -> Object {
let result = Array::new();
let text = self.fullwidth_filter.filter(text);
let s = if let Ok(s) = Sentence::from_raw(text) {
s
} else {
return result.into();
};
let s = self.predictor.predict(s);
let s = self
.post_filters
.iter()
.fold(s, |s, filter| filter.filter(s));

for &b in s.boundaries() {
result.push(&JsValue::from_bool(b == BoundaryType::WordBoundary));
}
result.into()
}

#[wasm_bindgen]
pub fn predict_with_score(&self, text: &str) -> Object {
let result = Array::new();
let text = self.fullwidth_filter.filter(text);
let s = if let Ok(s) = Sentence::from_raw(text) {
let mut s = if let Ok(s) = Sentence::from_raw(text) {
s
} else {
return result.into();
};
let s = self.predictor.predict_with_score(s);
let s = self
.post_filters
self.predictor.predict(&mut s);
self.post_filters
.iter()
.fold(s, |s, filter| filter.filter(s));
.for_each(|filter| filter.filter(&mut s));

for (&score, &b) in s.boundary_scores().iter().zip(s.boundaries()) {
let boundary = Array::new();
boundary.push(&(b == BoundaryType::WordBoundary).into());
boundary.push(&(b == CharacterBoundary::WordBoundary).into());
boundary.push(&score.into());
result.push(&boundary);
}
2 changes: 1 addition & 1 deletion examples/wasm/www/index.js
@@ -13,7 +13,7 @@ vaporetto_bccwj_suw_small().then((Vaporetto) => {

input_text.addEventListener("input", (e) => {
const text = input_text.value;
const scores = vaporetto_suw.predict_with_score(text);
const scores = vaporetto_suw.predict(text);
let i = -1;
while (tokenized.firstChild) {
tokenized.removeChild(tokenized.firstChild);
