diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 0e735a18..3e0dbc1b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,8 +1,8 @@ on: push: - branches: [ main ] + branches: [ main, develop ] pull_request: - branches: [ main ] + branches: [ main, develop ] name: build @@ -48,11 +48,11 @@ jobs: command: test args: --release -p vaporetto --no-default-features - - name: Run cargo test (vaporetto / all-features) + - name: Run cargo test (vaporetto / features kytea+train) uses: actions-rs/cargo@v1 with: command: test - args: --release -p vaporetto --all-features + args: --release -p vaporetto --features kytea,train nightly: name: Nightly diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index c2e05dab..00000000 --- a/.gitmodules +++ /dev/null @@ -1,12 +0,0 @@ -[submodule "bench/kytea"] - path = bench/kytea - url = https://github.com/neubig/kytea.git -[submodule "bench/lindera"] - path = bench/lindera - url = https://github.com/lindera-morphology/lindera.git -[submodule "bench/mecab"] - path = bench/mecab - url = https://github.com/taku910/mecab.git -[submodule "bench/sudachi.rs"] - path = bench/sudachi.rs - url = https://github.com/WorksApplications/sudachi.rs.git diff --git a/Cargo.toml b/Cargo.toml index a3ba8aa0..0a60007c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,9 +3,14 @@ members = [ "vaporetto", "vaporetto_rules", + "vaporetto_tantivy", + "manipulate_model", "predict", "train", "evaluate", "convert_kytea_model", +] + +exclude = [ "vaporetto_wasm", ] diff --git a/README-ja.md b/README-ja.md new file mode 100644 index 00000000..cd4f79e5 --- /dev/null +++ b/README-ja.md @@ -0,0 +1,217 @@ +# 🛥 VAporetto: POintwise pREdicTion based TOkenizer + +Vaporetto は、高速で軽量な点予測に基づくトークナイザです。 +このリポジトリには、 Vaporetto の API を提供する Rust のクレートと、 CLI フロントエンドが含まれています。 + +[![Crates.io](https://img.shields.io/crates/v/vaporetto)](https://crates.io/crates/vaporetto) +[![Documentation](https://docs.rs/vaporetto/badge.svg)](https://docs.rs/vaporetto) +![Build Status](https://github.com/legalforce-research/vaporetto/actions/workflows/rust.yml/badge.svg) + +[技術解説](https://tech.legalforce.co.jp/entry/2021/09/28/180844) + +[English document](README.md) + +## 使用例 + +### トークン化を試す + +このソフトウェアは Rust で実装されています。事前に[ドキュメント](https://www.rust-lang.org/tools/install)に従って `rustc` と `cargo` をインストールしてください。 + +Vaporetto はトークン化モデルを生成するための方法を3つ用意しています。 + +#### 配布モデルをダウンロードする + +1番目は最も単純な方法で、我々によって学習されたモデルをダウンロードすることです。 +モデルファイルは[ここ](https://github.com/legalforce-research/vaporetto/releases/tag/v0.3.0)にあります。 + +`bccwj-suw+unidic+tag` を選びました。 +``` +% wget https://github.com/legalforce-research/vaporetto/releases/download/v0.3.0/bccwj-suw+unidic+tag.tar.xz +``` + +各ファイルにはモデルファイルとライセンス条項が含まれているので、以下のようなコマンドでダウンロードしたファイルを展開する必要があります。 +``` +% tar xf ./bccwj-suw+unidic+tag.tar.xz +``` + +トークン化を行うには、以下のコマンドを実行します。 +``` +% echo 'ヴェネツィアはイタリアにあります。' | cargo run --release -p predict -- --model path/to/bccwj-suw+unidic+tag.model.zst +``` + +以下が出力されるでしょう。 +``` +ヴェネツィア は イタリア に あり ます 。 +``` + +#### KyTea のモデルを変換する + +2番目の方法も単純で、 KyTea で学習されたモデルを変換することです。 +まずはじめに、好きなモデルを [KyTea Models](http://www.phontron.com/kytea/model.html) ページからダウンロードします。 + +`jp-0.4.7-5.mod.gz` を選びました。 +``` +% wget http://www.phontron.com/kytea/download/model/jp-0.4.7-5.mod.gz +``` + +各モデルは圧縮されているので、以下のようなコマンドでダウンロードしたモデルを展開する必要があります。 +``` +% gunzip ./jp-0.4.7-5.mod.gz +``` + +KyTea のモデルを Vaporetto のモデルに変換するには、 Vaporetto のルートディレクトリで以下のコマンドを実行します。 +``` +% cargo 
run --release -p convert_kytea_model -- --model-in path/to/jp-0.4.7-5.mod --model-out path/to/jp-0.4.7-5-tokenize.model.zst +``` + +これでトークン化を行えます。以下のコマンドを実行します。 +``` +% echo 'ヴェネツィアはイタリアにあります。' | cargo run --release -p predict -- --model path/to/jp-0.4.7-5-tokenize.model.zst +``` + +以下が出力されるでしょう。 +``` +ヴェネツィア は イタリア に あ り ま す 。 +``` + +#### 自分のモデルを学習する + +3番目は主に研究者向けで、自分で学習コーパスを用意し、自分でトークン化モデルを学習することです。 + +Vaporetto は2種類のコーパス、すなわちフルアノテーションコーパスと部分アノテーションコーパスから学習することが可能です。 + +フルアノテーションコーパスは、すべての文字境界に対してトークン境界であるかトークンの内部であるかがアノテーションされたコーパスです。 +このデータは、以下に示すようにトークン境界に空白が挿入された形式です。 + +``` +ヴェネツィア は イタリア に あり ます 。 +火星 猫 の 生態 の 調査 結果 +``` + +一方、部分アノテーションコーパスは一部の文字境界のみに対してアノテーションされたコーパスです。 +各文字境界には `|` (トークン境界)、 `-` (非トークン境界)、 ` ` (不明) のいずれかの形式でアノテーションされます。 + +ここに例を示します。 +``` +ヴ-ェ-ネ-ツ-ィ-ア|は|イ-タ-リ-ア|に|あ り ま す|。 +火-星 猫|の|生-態|の|調-査 結-果 +``` + +モデルを学習するには、以下のコマンドを使用します。 +``` +% cargo run --release -p train -- --model ./your.model.zst --tok path/to/full.txt --part path/to/part.txt --dict path/to/dict.txt +``` + +`--tok` 引数ではフルアノテーションコーパスを指定し、 `--part` 引数では部分アノテーションコーパスを指定します。 +`--dict` 引数によって単語辞書を指定することもできます。 +単語辞書は、1行1単語のファイルです。 + +学習器は空行の入力を受け付けません。 +このため、学習の前にコーパスから空行を削除してください。 + +上記の引数は複数回指定することが可能です。 + +### モデルの編集 + +時々、モデルが期待とは異なる結果を出力することがあるでしょう。 +例えば、以下のコマンドで `メロンパン` は2つのトークンに分割されます。 +`--scores` オプションを使って、各文字間のスコアを出力します。 +``` +% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize.model.zst +朝食 は メロン パン 1 個 だっ た +0:朝食 -15398 +1:食は 24623 +2:はメ 30261 +3:メロ -26885 +4:ロン -38896 +5:ンパ 8162 +6:パン -23416 +7:ン1 23513 +8:1個 18435 +9:個だ 24964 +10:だっ -15065 +11:った 14178 +``` + +`メロンパン` を単一のトークンに連結するには、以下の手順でモデルを編集し、 `ンパ` のスコアを負にします。 + +1. 以下のコマンドで辞書を吐き出します。 + ``` + % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --dump-dict path/to/dictionary.csv + ``` + +2. 辞書を編集します。 + + 辞書は CSV ファイルです。各行には単語と、対応する重みとコメントが以下の順で含まれています。 + + * `right_weight` - 単語が境界の右側に見つかった際に追加される重み。 + * `inside_weight` - 単語が境界に重なっている際に追加される重み。 + * `left_weight` - 単語が境界の左側に見つかった際に追加される重み。 + * `comment` - 挙動に影響しないコメント + + Vaporetto は、重みの合計が正の値になった際にテキストを分割するので、以下のように新しいエントリを追加します。 + ```diff + メロレオストーシス,6944,-2553,5319, + メロン,8924,-10861,7081, + +メロンパン,0,-100000,0,melon🍈 bread🍞 in English. + メロン果実,4168,-1165,3558, + メロヴィング,6999,-15413,7583, + ``` + + この場合、境界が `メロンパン` の内側だった際に `-100000` が追加されます。 + + Vaporetto は重みの合計値に 32-bit 整数を利用しているため、オーバーフローに気をつけてください。 + + 加えて、辞書には重複する単語を含めることができません。 + 単語が既に辞書に含まれている際は、既存の重みを編集する必要があります。 + +3. 
モデルファイルの重みを置換します。 + ``` + % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --replace-dict path/to/dictionary.csv --model-out path/to/jp-0.4.7-5-tokenize-new.model.zst + ``` + +これで `メロンパン` が単一のトークンに分割されます。 +``` +% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize-new.model.zst +朝食 は メロンパン 1 個 だっ た +0:朝食 -15398 +1:食は 24623 +2:はメ 30261 +3:メロ -126885 +4:ロン -138896 +5:ンパ -91838 +6:パン -123416 +7:ン1 23513 +8:1個 18435 +9:個だ 24964 +10:だっ -15065 +11:った 14178 +``` + +### 品詞推定 + +Vaporettoは実験的に品詞推定に対応しています。 + +品詞を学習するには、以下のように、データセットの各トークンに続けてスラッシュと品詞を追加します。 + +* フルアノテーションコーパスの場合 + ``` + この/連体詞 人/名詞 は/助詞 火星/名詞 人/接尾辞 です/助動詞 + ``` + +* 部分アノテーションコーパスの場合 + ``` + ヴ-ェ-ネ-ツ-ィ-ア/名詞|は/助詞|イ-タ-リ-ア/名詞|に/助詞|あ-り ま-す + ``` + +データセットに品詞が含まれる場合、 `train` コマンドは自動的にそれらを学習します。 + +推定時は、デフォルトでは品詞は推定されないため、必要に応じで `predict` コマンドに `--predict-tags` 引数を指定してください。 + +## 各種トークナイザの速度比較 + +Vaporetto は KyTea に比べて 8.25 倍速く動作します。 + +詳細は[ここ](https://github.com/legalforce-research/vaporetto/wiki/Speed-Comparison)を参照してください。 + +![](./figures/comparison.svg) diff --git a/README.md b/README.md index c4f49838..2e488453 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # 🛥 VAporetto: POintwise pREdicTion based TOkenizer Vaporetto is a fast and lightweight pointwise prediction based tokenizer. +This repository includes both a Rust crate that provides APIs for Vaporetto and CLI frontends. [![Crates.io](https://img.shields.io/crates/v/vaporetto)](https://crates.io/crates/vaporetto) [![Documentation](https://docs.rs/vaporetto/badge.svg)](https://docs.rs/vaporetto) @@ -8,19 +9,44 @@ Vaporetto is a fast and lightweight pointwise prediction based tokenizer. [Technical details](https://tech.legalforce.co.jp/entry/2021/09/28/180844) (Japanese) -## Overview +[日本語のドキュメント](README-ja.md) -This repository includes both a Rust crate that provides APIs for Vaporetto and CLI frontends. +## Example Usage ### Try Word Segmentation This software is implemented in Rust. Install `rustc` and `cargo` following [the documentation](https://www.rust-lang.org/tools/install) beforehand. -Vaporetto provides two ways to generate tokenization models: +Vaporetto provides three ways to generate tokenization models: + +#### Download Distribution Model + +The first is the simplest way, which is to download a model that has been trained by us. +You can find models [here](https://github.com/legalforce-research/vaporetto/releases/tag/v0.3.0). + +We chose `bccwj-suw+unidic+tag`: +``` +% wget https://github.com/legalforce-research/vaporetto/releases/download/v0.3.0/bccwj-suw+unidic+tag.tar.xz +``` + +Each file contains a model file and license terms, so you need to extract the downloaded file like the following command: +``` +% tar xf ./bccwj-suw+unidic+tag.tar.xz +``` + +To perform tokenization, run the following command: +``` +% echo 'ヴェネツィアはイタリアにあります。' | cargo run --release -p predict -- --model path/to/bccwj-suw+unidic+tag.model.zst +``` + +The following will be output: +``` +ヴェネツィア は イタリア に あり ます 。 +``` #### Convert KyTea's Model -The first is the simplest way, which is to convert a model that has been trained by KyTea. +The second is also a simple way, which is to convert a model that has been trained by KyTea. First of all, download the model of your choice from the [KyTea Models](http://www.phontron.com/kytea/model.html) page. 
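Whichever route you take (a distributed model as above, or a KyTea conversion as below), the resulting `.model.zst` file can also be loaded directly from Rust through the `vaporetto` crate instead of the `predict` CLI. The following is a minimal, untested sketch that relies only on the API calls visible in `evaluate/src/main.rs` later in this diff (`Model::read`, `Predictor::new`, `Sentence::from_raw`, `predict`, `boundaries`); exact names and signatures may differ between crate versions:

```rust
// Minimal sketch: load a zstd-compressed model and segment one sentence.
// Only calls that appear elsewhere in this diff are used; treat this as
// illustrative rather than as the crate's documented interface.
use std::fs::File;

use vaporetto::{BoundaryType, Model, Predictor, Sentence};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut f = zstd::Decoder::new(File::open("path/to/bccwj-suw+unidic+tag.model.zst")?)?;
    let model = Model::read(&mut f)?;
    let predictor = Predictor::new(model, false)?; // second argument: POS tagging off

    let s = Sentence::from_raw("ヴェネツィアはイタリアにあります。")?;
    let s = predictor.predict(s);

    // Rebuild the tokenized form from the predicted character boundaries.
    let chars: Vec<char> = s.to_raw_string().chars().collect();
    let mut tokenized = String::new();
    for (i, c) in chars.iter().enumerate() {
        tokenized.push(*c);
        if s.boundaries().get(i) == Some(&BoundaryType::WordBoundary) {
            tokenized.push(' ');
        }
    }
    println!("{}", tokenized); // e.g. ヴェネツィア は イタリア に あり ます 。
    Ok(())
}
```

This mirrors what the `predict` CLI does, minus the normalization and post-filters (`KyteaFullwidthFilter`, `ConcatGraphemeClustersFilter`, `KyteaWsConstFilter`) that `evaluate/src/main.rs` applies.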
We chose `jp-0.4.7-5.mod.gz`: @@ -34,24 +60,23 @@ Each model is compressed, so you need to decompress the downloaded model file li ``` To convert a KyTea model into a Vaporetto model, run the following command in the Vaporetto root directory. -If necessary, the Rust code will be compiled before the conversion process. ``` -% cargo run --release -p convert_kytea_model -- --model-in path/to/jp-0.4.7-5.mod --model-out path/to/jp-0.4.7-5-tokenize.model.zstd +% cargo run --release -p convert_kytea_model -- --model-in path/to/jp-0.4.7-5.mod --model-out path/to/jp-0.4.7-5-tokenize.model.zst ``` Now you can perform tokenization. Run the following command: ``` -% echo '火星猫の生態の調査結果' | cargo run --release -p predict -- --model path/to/jp-0.4.7-5-tokenize.model.zstd +% echo 'ヴェネツィアはイタリアにあります。' | cargo run --release -p predict -- --model path/to/jp-0.4.7-5-tokenize.model.zst ``` The following will be output: ``` -火星 猫 の 生態 の 調査 結果 +ヴェネツィア は イタリア に あ り ま す 。 ``` #### Train Your Own Model -The second way, which is mainly for researchers, is to prepare your own training corpus and train your own tokenization models. +The third way, which is mainly for researchers, is to prepare your own training corpus and train your own tokenization models. Vaporetto can train from two types of corpora: fully annotated corpora and partially annotated corpora. @@ -59,7 +84,7 @@ Fully annotated corpora are corpora in which all character boundaries are annota This is the data in the form of spaces inserted into the boundaries of the tokens, as shown below: ``` -ヴェネツィア は イタリア に あ り ま す 。 +ヴェネツィア は イタリア に あり ます 。 火星 猫 の 生態 の 調査 結果 ``` @@ -75,56 +100,122 @@ Here is an example: To train a model, use the following command: ``` -% cargo run --release -p train -- --model ./your.model.zstd --tok path/to/full.txt --part path/to/part.txt --dict path/to/dict.txt +% cargo run --release -p train -- --model ./your.model.zst --tok path/to/full.txt --part path/to/part.txt --dict path/to/dict.txt ``` `--tok` argument specifies a fully annotated corpus, and `--part` argument specifies a partially annotated corpus. You can also specify a word dictionary with `--dict` argument. A word dictionary is a file with words per line. +The trainer does not accept empty lines. +Therefore, remove all empty lines from the corpus before training. + You can specify all arguments above multiple times. +### Model Manipulation + +Sometimes, your model will output different results than what you expect. +For example, `メロンパン` is split into two tokens in the following command. +We use `--scores` option to show the score of each character boundary: +``` +% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize.model.zst +朝食 は メロン パン 1 個 だっ た +0:朝食 -15398 +1:食は 24623 +2:はメ 30261 +3:メロ -26885 +4:ロン -38896 +5:ンパ 8162 +6:パン -23416 +7:ン1 23513 +8:1個 18435 +9:個だ 24964 +10:だっ -15065 +11:った 14178 +``` + +To concatenate `メロンパン` into a single token, manipulate the model in the following steps so that the score of `ンパ` becomes negative: + +1. Dump a dictionary by the following command: + ``` + % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --dump-dict path/to/dictionary.csv + ``` + +2. Edit the dictionary. + + The dictionary is a csv file. Each row contains a word, corresponding weights, and a comment in the following order: + + * `right_weight` - A weight that is added when the word is found to the right of the boundary. 
+ * `inside_weight` - A weight that is added when the word is overlapped on the boundary. + * `left_weight` - A weight that is added when the word is found to the left of the boundary. + * `comment` - A comment that does not affect the behaviour. + + Vaporetto splits a text when the total weight of the boundary is a positive number, so we add a new entry as follows: + ```diff + メロレオストーシス,6944,-2553,5319, + メロン,8924,-10861,7081, + +メロンパン,0,-100000,0,melon🍈 bread🍞 in English. + メロン果実,4168,-1165,3558, + メロヴィング,6999,-15413,7583, + ``` + + In this case, `-100000` will be added when the boundary is inside of the word `メロンパン`. + + Note that Vaporetto uses 32-bit integers for the total weight, so you have to be careful about overflow. + + In addition, The dictionary cannot contain duplicated words. + When the word is already contained in the dictionary, you have to edit existing weights. + +3. Replaces weight data of a model file + ``` + % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --replace-dict path/to/dictionary.csv --model-out path/to/jp-0.4.7-5-tokenize-new.model.zst + ``` + +Now `メロンパン` is split into a single token. +``` +% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize-new.model.zst +朝食 は メロンパン 1 個 だっ た +0:朝食 -15398 +1:食は 24623 +2:はメ 30261 +3:メロ -126885 +4:ロン -138896 +5:ンパ -91838 +6:パン -123416 +7:ン1 23513 +8:1個 18435 +9:個だ 24964 +10:だっ -15065 +11:った 14178 +``` + +### POS tagging + +Vaporetto experimentally supports POS tagging. + +To train tags, add a slash and tag name following each token in the dataset as follows: + +* For fully annotated corpora + ``` + この/連体詞 人/名詞 は/助詞 火星/名詞 人/接尾辞 です/助動詞 + ``` + +* For partially annotated corpora + ``` + ヴ-ェ-ネ-ツ-ィ-ア/名詞|は/助詞|イ-タ-リ-ア/名詞|に/助詞|あ-り ま-す + ``` + +If the dataset contains tags, the `train` command automatically trains them. + +In prediction, tags are not predicted by default, so you have to specify `--predict-tags` argument to `predict` command if necessary. + ## Speed Comparison of Various Tokenizers -### Experimental Setup - -* Document: Japanese training data of Kyoto Free Translation Task -* Models: - * KyTea and Vaporetto: Compact LR model (jp-0.4.7-6) - * MeCab, Kuromoji, and Lindera: IPAdic - * Sudachi and Sudachi.rs: system_core.dic (v20210802) - -### Results - -* VM instance on Google Cloud Platform (c2-standard-16, Debian) - - | Tool Name (version) | Speed (×10^6 chars/s) | σ | - | -------------------------- | ---------------------:|-------| - | KyTea (2020-04-03) | 0.777 | 0.020 | - | Vaporetto (0.1.6) | **4.426** | 0.182 | - | | | | - | MeCab (2020-09-14) | 2.736 | 0.041 | - | | | | - | Kuromoji (Atilika's 0.9.0) | 0.423 | 0.013 | - | Lindera (0.8.0) | 1.002 | 0.014 | - | | | | - | Sudachi (0.5.2) | 0.251 | 0.012 | - | Sudachi.rs (0.6.0-rc1) | 0.644 | 0.012 | - -* MacBook Pro (2017, Processor: 2.3 GHz Intel Core i5, Memory: 8 GB 2133 MHz LPDDR3) - - | Tool Name (version) | Speed (×10^6 chars/s) | σ | - | -------------------------- | ---------------------:|-------| - | KyTea (2020-04-03) | 0.490 | 0.003 | - | Vaporetto (0.1.6) | **3.016** | 0.113 | - | | | | - | MeCab (2020-09-14) | 1.418 | 0.007 | - | | | | - | Kuromoji (Atilika's 0.9.0) | 1.197 | 0.034 | - | Lindera (0.8.0) | 0.542 | 0.010 | - | | | | - | Sudachi (0.5.2) | 0.439 | 0.026 | - | Sudachi.rs (0.6.0-rc1) | 0.427 | 0.009 | +Vaporetto is 8.25 times faster than KyTea. 
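As with segmentation, the tag prediction described in the POS tagging section above is also reachable from the crate. Below is a rough, untested sketch, again using only calls that appear in `evaluate/src/main.rs` later in this diff (`Predictor::new(model, true)` and `fill_tags`); the exact shape of the returned tags depends on the crate version:

```rust
// Rough sketch of POS tag prediction through the crate, mirroring what the
// `predict` CLI does when `--predict-tags` is given. Untested; API details
// may differ between versions.
use std::fs::File;

use vaporetto::{Model, Predictor, Sentence};

fn tag_example() -> Result<(), Box<dyn std::error::Error>> {
    let mut f = zstd::Decoder::new(File::open("path/to/bccwj-suw+unidic+tag.model.zst")?)?;
    let model = Model::read(&mut f)?;

    // The second argument enables tag prediction (cf. `--predict-tags`).
    let predictor = Predictor::new(model, true)?;

    let s = Sentence::from_raw("この人は火星人です")?;
    let s = predictor.predict(s);
    let s = predictor.fill_tags(s);

    // `tags()` exposes the predicted tags alongside `boundaries()`;
    // we only debug-print them here since the element type is version-specific.
    println!("{:?}", s.tags());
    Ok(())
}
```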
+ +Details can be found [here](https://github.com/legalforce-research/vaporetto/wiki/Speed-Comparison). + +![](./figures/comparison.svg) ## Disclaimer diff --git a/bench/README.md b/bench/README.md deleted file mode 100644 index f66549b2..00000000 --- a/bench/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Benchmarking of various tokenizers - -## Preparation - -``` -% git submodule update --init -% ./download_resources.sh -% ./compile_all.sh -``` - -## Measurement - -``` -% ./run_all.sh 2>&1 | tee ./results -% ./stats.py < ./results -``` diff --git a/bench/compile_all.sh b/bench/compile_all.sh deleted file mode 100755 index d9e3fabe..00000000 --- a/bench/compile_all.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash - -set -eux - -which patch -which cargo -which autoreconf -which libtool -which make -which mvn - -set +e - -patch -p1 -N < ./elapsed_time.patch - -set -e - -pushd .. -cargo build --release -./target/release/convert_kytea_model --model-in "./bench/kytea/jp-0.4.7-6.mod" --model-out "./jp-0.4.7-6.tokenize.mod" -popd - -pushd ./kytea -autoreconf -i -./configure -make -popd - -pushd ./mecab/mecab -./configure --prefix=$(cd .. && pwd)/tmpusr -make -make install -popd -pushd ./mecab/mecab-ipadic -./configure --with-charset=utf8 --prefix=$(cd .. && pwd)/tmpusr --with-mecab-config=../mecab/mecab-config -make -make install -popd - -pushd ./kuromoji -mvn compile -popd - -pushd ./lindera -cargo build --release -popd - -pushd ./sudachi -mvn compile -popd - -pushd ./sudachi.rs -cargo build --release -popd diff --git a/bench/download_resources.sh b/bench/download_resources.sh deleted file mode 100755 index 4f5e91df..00000000 --- a/bench/download_resources.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -set -eux - -which wget -which gunzip -which unzip -which tar - -pushd ./kytea -wget "http://www.phontron.com/kytea/download/model/jp-0.4.7-6.mod.gz" -gunzip "./jp-0.4.7-6.mod.gz" -popd -pushd ./sudachi -wget "http://sudachi.s3-website-ap-northeast-1.amazonaws.com/sudachidict/sudachi-dictionary-20210802-core.zip" -unzip "./sudachi-dictionary-20210802-core.zip" -popd -pushd ./sudachi.rs -./fetch_dictionary.sh -popd - -wget "http://www.phontron.com/kftt/download/kftt-data-1.0.tar.gz" -tar xf "./kftt-data-1.0.tar.gz" diff --git a/bench/elapsed_time.patch b/bench/elapsed_time.patch deleted file mode 100644 index 9f5211a1..00000000 --- a/bench/elapsed_time.patch +++ /dev/null @@ -1,114 +0,0 @@ ---- a/kytea/src/lib/kytea.cpp -+++ b/kytea/src/lib/kytea.cpp -@@ -19,6 +19,7 @@ - #include - #include - #include -+#include - #include - #include - #include -@@ -1206,6 +1207,8 @@ void Kytea::analyze() { - for(int i = 0; i < config_->getNumTags(); i++) - out->setDoTag(i,config_->getDoTag(i)); - -+ chrono::steady_clock::time_point begin = chrono::steady_clock::now(); -+ - KyteaSentence* next; - while((next = in->readSentence()) != 0) { - if(config_->getDoWS()) -@@ -1218,6 +1221,9 @@ void Kytea::analyze() { - delete next; - } - -+ chrono::steady_clock::time_point end = chrono::steady_clock::now(); -+ cerr << "Elapsed-kytea: " << (double) chrono::duration_cast(end - begin).count() / 1000 << " [sec]" << endl; -+ - delete in; - delete out; - if(inStr) delete inStr; ---- a/mecab/mecab/src/tagger.cpp -+++ b/mecab/mecab/src/tagger.cpp -@@ -6,6 +6,7 @@ - #include - #include - #include -+#include - #include "common.h" - #include "connector.h" - #include "mecab.h" -@@ -1229,6 +1230,8 @@ int mecab_do(int argc, char **argv) { - WHAT_ERROR("cannot create tagger"); - } - -+ std::chrono::steady_clock::time_point begin = 
std::chrono::steady_clock::now(); -+ - for (size_t i = 0; i < rest.size(); ++i) { - MeCab::istream_wrapper ifs(rest[i].c_str()); - if (!*ifs) { -@@ -1255,6 +1258,8 @@ int mecab_do(int argc, char **argv) { - std::strncpy(ibuf, sentence.c_str(), ibufsize); - } - if (ifs->eof() && !ibuf[0]) { -+ std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); -+ std::cerr << "Elapsed-mecab: " << (double) std::chrono::duration_cast(end - begin).count() / 1000 << " [sec]" << std::endl; - return false; - } - if (ifs->fail()) { ---- a/lindera/lindera-cli/src/main.rs -+++ b/lindera/lindera-cli/src/main.rs -@@ -2,6 +2,7 @@ use std::fs; - use std::io; - use std::io::{BufRead, BufReader}; - use std::path::Path; -+use std::time::Instant; - - use clap::{crate_authors, crate_description, crate_version, App, AppSettings, Arg}; - -@@ -123,6 +124,8 @@ fn main() -> LinderaResult<()> { - Box::new(BufReader::new(io::stdin())) - }; - -+ let start = Instant::now(); -+ - loop { - // read the text to be tokenized from stdin - let mut text = String::new(); -@@ -145,5 +148,8 @@ fn main() -> LinderaResult<()> { - }; - } - -+ let duration = start.elapsed(); -+ eprintln!("Elapsed-lindera: {} [sec]", duration.as_secs_f64()); -+ - Ok(()) - } ---- a/sudachi.rs/sudachi-cli/src/main.rs -+++ b/sudachi.rs/sudachi-cli/src/main.rs -@@ -20,6 +20,7 @@ use std::fs::File; - use std::io::{self, BufRead, BufReader, BufWriter, Write}; - use std::path::PathBuf; - use std::process; -+use std::time::Instant; - - use structopt::StructOpt; - -@@ -132,6 +133,8 @@ fn main() { - - let format = make_output::<&JapaneseDictionary>(&args); - -+ let start = Instant::now(); -+ - // tokenize and output results - for line in reader.lines() { - let input = line.expect("Failed to read line"); -@@ -157,6 +160,9 @@ fn main() { - } - // it is recommended to call write before dropping BufWriter - writer.flush().expect("flush failed"); -+ -+ let duration = start.elapsed(); -+ eprintln!("Elapsed-sudachi.rs: {} [sec]", duration.as_secs_f64()); - } - - fn make_output(cli: &Cli) -> Box> { diff --git a/bench/kuromoji/pom.xml b/bench/kuromoji/pom.xml deleted file mode 100644 index 5f88e1a3..00000000 --- a/bench/kuromoji/pom.xml +++ /dev/null @@ -1,72 +0,0 @@ - - - - 4.0.0 - - kuromoji_bench - kuromoji_bench - 1.0-SNAPSHOT - - kuromoji_bench - - - UTF-8 - 1.7 - 1.7 - - - - - com.atilika.kuromoji - kuromoji-ipadic - 0.9.0 - - - - - - - - - maven-clean-plugin - 3.1.0 - - - - maven-resources-plugin - 3.0.2 - - - maven-compiler-plugin - 3.8.0 - - - maven-surefire-plugin - 2.22.1 - - - maven-jar-plugin - 3.0.2 - - - maven-install-plugin - 2.5.2 - - - maven-deploy-plugin - 2.8.2 - - - - maven-site-plugin - 3.7.1 - - - maven-project-info-reports-plugin - 3.0.0 - - - - - diff --git a/bench/kuromoji/src/main/java/kuromoji_bench/App.java b/bench/kuromoji/src/main/java/kuromoji_bench/App.java deleted file mode 100644 index 7f347d38..00000000 --- a/bench/kuromoji/src/main/java/kuromoji_bench/App.java +++ /dev/null @@ -1,28 +0,0 @@ -package kuromoji_bench; - -import com.atilika.kuromoji.ipadic.Token; -import com.atilika.kuromoji.ipadic.Tokenizer; -import java.util.List; -import java.util.ArrayList; -import java.util.Scanner; -import java.time.Instant; -import java.time.Duration; - -public class App { - public static void main(String[] args) { - Tokenizer tokenizer = new Tokenizer(); - Scanner input = new Scanner(System.in); - Instant start = Instant.now(); - while (input.hasNext()) { - List tokens = tokenizer.tokenize(input.nextLine()); - List words = new 
ArrayList(); - for (Token token : tokens) { - words.add(token.getSurface()); - } - System.out.println(String.join(" ", words)); - } - Instant finish = Instant.now(); - double timeElapsed = (double) Duration.between(start, finish).toMillis() / 1000; - System.err.println("Elapsed-kuromoji: " + timeElapsed + " [sec]"); - } -} diff --git a/bench/kytea b/bench/kytea deleted file mode 160000 index 73a94c4a..00000000 --- a/bench/kytea +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 73a94c4a3045087a7e90f27700f3b870a72625e7 diff --git a/bench/lindera b/bench/lindera deleted file mode 160000 index 0f500336..00000000 --- a/bench/lindera +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 0f50033653631261a290ae4ac94cc16bfe63f3bb diff --git a/bench/mecab b/bench/mecab deleted file mode 160000 index 046fa78b..00000000 --- a/bench/mecab +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 046fa78b2ed56fbd4fac312040f6d62fc1bc31e3 diff --git a/bench/run_all.sh b/bench/run_all.sh deleted file mode 100755 index 6b6f2365..00000000 --- a/bench/run_all.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -set -eux - -INPUT_DATA="./kftt-data-1.0/data/orig/kyoto-train.ja" - -for i in 0 1 2 3 4 5 6 7 8 9 -do - for j in 0 1 2 3 4 5 6 7 8 9 - do - echo "iter" $i $j - - ./kytea/src/bin/kytea -model "./kytea/jp-0.4.7-6.mod" -notags < $INPUT_DATA > /dev/null - - ../target/release/predict --model "../jp-0.4.7-6.tokenize.mod" < $INPUT_DATA > /dev/null - - ./mecab/tmpusr/bin/mecab -Owakati < $INPUT_DATA > /dev/null - - pushd ./kuromoji - mvn exec:java -Dexec.mainClass=kuromoji_bench.App < ../$INPUT_DATA > /dev/null - popd - - ./lindera/target/release/lindera -O wakati < $INPUT_DATA > /dev/null - - pushd ./sudachi - mvn exec:java -Dexec.mainClass=sudachi_bench.App < ../$INPUT_DATA > /dev/null - popd - - ./sudachi.rs/target/release/sudachi -w -m C < $INPUT_DATA > /dev/null - done -done diff --git a/bench/stats.py b/bench/stats.py deleted file mode 100755 index 9004493f..00000000 --- a/bench/stats.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 - -from __future__ import annotations - -import collections -import math -import re -import sys - - -RE_DICT = [ - ('kytea', re.compile(r'Elapsed-kytea: ([0-9\.]+) \[sec\]')), - ('vaporetto', re.compile(r'Elapsed: ([0-9\.]+) \[sec\]')), - ('mecab', re.compile(r'Elapsed-mecab: ([0-9\.]+) \[sec\]')), - ('kuromoji', re.compile(r'Elapsed-kuromoji: ([0-9\.]+) \[sec\]')), - ('lindera', re.compile(r'Elapsed-lindera: ([0-9\.]+) \[sec\]')), - ('sudachi', re.compile(r'Elapsed-sudachi: ([0-9\.]+) \[sec\]')), - ('sudachi.rs', re.compile(r'Elapsed-sudachi.rs: ([0-9\.]+) \[sec\]')), -] - -N_CHARS = 16318893 - - -def mean_std(times: list[float]) -> (float, float): - speeds = [N_CHARS / time for time in times] - mean = sum(speeds) / len(speeds) - dist = sum((speed - mean) ** 2 for speed in speeds) / len(speeds) - return mean, math.sqrt(dist) - - -def _main(): - times = collections.defaultdict(list) - for line in sys.stdin: - for name, r in RE_DICT: - m = r.match(line) - if m is not None: - times[name].append(float(m.group(1))) - break - - for name, _ in RE_DICT: - mean, std = mean_std(times[name]) - print(f'{name} {mean} {std}') - - -if __name__ == '__main__': - _main() diff --git a/bench/sudachi.rs b/bench/sudachi.rs deleted file mode 160000 index 1cf62ec2..00000000 --- a/bench/sudachi.rs +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1cf62ec2d6949db76e5aa2625c9b76f747960ac1 diff --git a/bench/sudachi/pom.xml b/bench/sudachi/pom.xml deleted file mode 100644 index 26d7f26d..00000000 --- 
a/bench/sudachi/pom.xml +++ /dev/null @@ -1,72 +0,0 @@ - - - - 4.0.0 - - sudachi_bench - sudachi_bench - 1.0-SNAPSHOT - - sudachi_bench - - - UTF-8 - 1.7 - 1.7 - - - - - com.worksap.nlp - sudachi - 0.5.2 - - - - - - - - - maven-clean-plugin - 3.1.0 - - - - maven-resources-plugin - 3.0.2 - - - maven-compiler-plugin - 3.8.0 - - - maven-surefire-plugin - 2.22.1 - - - maven-jar-plugin - 3.0.2 - - - maven-install-plugin - 2.5.2 - - - maven-deploy-plugin - 2.8.2 - - - - maven-site-plugin - 3.7.1 - - - maven-project-info-reports-plugin - 3.0.0 - - - - - diff --git a/bench/sudachi/src/main/java/sudachi_bench/App.java b/bench/sudachi/src/main/java/sudachi_bench/App.java deleted file mode 100644 index ac249c98..00000000 --- a/bench/sudachi/src/main/java/sudachi_bench/App.java +++ /dev/null @@ -1,36 +0,0 @@ -package sudachi_bench; - -import java.io.IOException; -import com.worksap.nlp.sudachi.Tokenizer; -import com.worksap.nlp.sudachi.Dictionary; -import com.worksap.nlp.sudachi.DictionaryFactory; -import com.worksap.nlp.sudachi.Morpheme; -import java.util.List; -import java.util.ArrayList; -import java.util.Scanner; -import java.time.Instant; -import java.time.Duration; -import java.nio.file.Paths; -import java.nio.file.Files; - -public class App { - public static void main(String[] args) throws IOException { - String settings = Files.readString(Paths.get("sudachi.json")); - Scanner input = new Scanner(System.in); - try (Dictionary dict = new DictionaryFactory().create(settings)) { - Tokenizer tokenizer = dict.create(); - Instant start = Instant.now(); - while (input.hasNext()) { - List tokens = tokenizer.tokenize(Tokenizer.SplitMode.C, input.nextLine()); - List words = new ArrayList(); - for (Morpheme token : tokens) { - words.add(token.surface()); - } - System.out.println(String.join(" ", words)); - } - Instant finish = Instant.now(); - double timeElapsed = (double) Duration.between(start, finish).toMillis() / 1000; - System.err.println("Elapsed-sudachi: " + timeElapsed + " [sec]"); - } - } -} diff --git a/bench/sudachi/sudachi.json b/bench/sudachi/sudachi.json deleted file mode 100644 index 9a94c67c..00000000 --- a/bench/sudachi/sudachi.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "systemDict" : "sudachi-dictionary-20210802/system_core.dic", - "inputTextPlugin" : [ - { "class" : "com.worksap.nlp.sudachi.DefaultInputTextPlugin" }, - { "class" : "com.worksap.nlp.sudachi.ProlongedSoundMarkInputTextPlugin", - "prolongedSoundMarks": ["ー", "-", "⁓", "〜", "〰"], - "replacementSymbol": "ー"} - ], - "oovProviderPlugin" : [ - { "class" : "com.worksap.nlp.sudachi.MeCabOovProviderPlugin" }, - { "class" : "com.worksap.nlp.sudachi.SimpleOovProviderPlugin", - "oovPOS" : [ "補助記号", "一般", "*", "*", "*", "*" ], - "leftId" : 5968, - "rightId" : 5968, - "cost" : 3857 } - ], - "pathRewritePlugin" : [ - { "class" : "com.worksap.nlp.sudachi.JoinNumericPlugin", - "joinKanjiNumeric" : true }, - { "class" : "com.worksap.nlp.sudachi.JoinKatakanaOovPlugin", - "oovPOS" : [ "名詞", "普通名詞", "一般", "*", "*", "*" ], - "minLength" : 3 - } - ] -} diff --git a/evaluate/src/main.rs b/evaluate/src/main.rs index 8dfeb791..99a28e8e 100644 --- a/evaluate/src/main.rs +++ b/evaluate/src/main.rs @@ -33,6 +33,23 @@ impl FromStr for WsConst { } } +#[derive(Debug)] +enum EvaluationMetric { + CharBoundaryAccuracy, + WordAccuracy, +} + +impl FromStr for EvaluationMetric { + type Err = &'static str; + fn from_str(metric: &str) -> Result { + match metric { + "char" => Ok(Self::CharBoundaryAccuracy), + "word" => Ok(Self::WordAccuracy), + _ => Err("Could not 
parse a metric value"), + } + } +} + #[derive(StructOpt, Debug)] #[structopt( name = "evaluate", @@ -43,6 +60,10 @@ struct Opt { #[structopt(long)] model: PathBuf, + /// Predicts POS tags. + #[structopt(long)] + predict_tags: bool, + /// Do not segment some character types: {D, R, H, T, K, O, G}. /// D: Digit, R: Roman, H: Hiragana, T: Katakana, K: Kanji, O: Other, G: Grapheme cluster. #[structopt(long)] @@ -51,18 +72,22 @@ struct Opt { /// Do not normalize input strings before prediction. #[structopt(long)] no_norm: bool, + + /// Evaluation metric: {char, word}. + /// char: evaluates each charactor boundary. + /// word: evaluates each word using Nagata's method. + #[structopt(long, default_value = "char")] + metric: EvaluationMetric, } fn main() -> Result<(), Box> { let opt = Opt::from_args(); - let fullwidth_filter = KyteaFullwidthFilter::new(); + let fullwidth_filter = KyteaFullwidthFilter; let mut post_filters: Vec> = vec![]; for wsconst in &opt.wsconst { match wsconst { - WsConst::GraphemeCluster => { - post_filters.push(Box::new(ConcatGraphemeClustersFilter::new())) - } + WsConst::GraphemeCluster => post_filters.push(Box::new(ConcatGraphemeClustersFilter)), WsConst::CharType(char_type) => { post_filters.push(Box::new(KyteaWsConstFilter::new(*char_type))) } @@ -72,45 +97,103 @@ fn main() -> Result<(), Box> { eprintln!("Loading model file..."); let mut f = zstd::Decoder::new(File::open(opt.model)?)?; let model = Model::read(&mut f)?; - let predictor = Predictor::new(model); + let predictor = Predictor::new(model, opt.predict_tags)?; eprintln!("Start tokenization"); - let mut n_true_positive = 0; - let mut n_false_positive = 0; - let mut n_false_negative = 0; + let mut results = vec![]; for line in stdin().lock().lines() { - let s = Sentence::from_tokenized(line?)?; - let s = if opt.no_norm { - s - } else { + let line = line?; + if line.is_empty() { + continue; + } + let mut s = Sentence::from_tokenized(line)?; + let ref_boundaries = s.boundaries().to_vec(); + let ref_tags = s.tags().to_vec(); + if !opt.no_norm { let new_line = fullwidth_filter.filter(s.to_raw_string()); - let mut new_s = Sentence::from_raw(new_line)?; - new_s.boundaries_mut().clone_from_slice(s.boundaries()); - new_s + s = Sentence::from_raw(new_line)? }; - let reference = s.boundaries().to_vec(); - let s = predictor.predict(s); - let s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); - for (&r, &h) in reference.iter().zip(s.boundaries()) { - if r == h { - if h == BoundaryType::WordBoundary { - n_true_positive += 1; + s = predictor.predict(s); + s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); + s = predictor.fill_tags(s); + let hyp_boundaries = s.boundaries().to_vec(); + let hyp_tags = s.tags().to_vec(); + results.push((ref_boundaries, ref_tags, hyp_boundaries, hyp_tags)); + } + + match opt.metric { + EvaluationMetric::CharBoundaryAccuracy => { + let mut n_tp = 0; + let mut n_tn = 0; + let mut n_fp = 0; + let mut n_fn = 0; + for (rs_b, _, hs_b, _) in results { + for (r, h) in rs_b.into_iter().zip(hs_b) { + if r == h { + if h == BoundaryType::WordBoundary { + n_tp += 1; + } else { + n_tn += 1; + } + } else if h == BoundaryType::WordBoundary { + n_fp += 1; + } else { + n_fn += 1; + } } - } else if h == BoundaryType::WordBoundary { - n_false_positive += 1; - } else { - n_false_negative += 1; } + let precision = n_tp as f64 / (n_tp + n_fp) as f64; + let recall = n_tp as f64 / (n_tp + n_fn) as f64; + let f1 = 2. 
* precision * recall / (precision + recall); + println!("Precision: {}", precision); + println!("Recall: {}", recall); + println!("F1: {}", f1); + println!("TP: {}, TN: {}, FP: {}, FN: {}", n_tp, n_tn, n_fp, n_fn); + } + EvaluationMetric::WordAccuracy => { + // Reference: + // Masaaki Nagata. 1994. A stochastic Japanese morphological analyzer using a forward-DP + // backward-A* n-best search algorithm. In COLING 1994 Volume 1: The 15th International + // Conference on Computational Linguistics. + let mut n_sys = 0; + let mut n_ref = 0; + let mut n_cor = 0; + for (rs_b, rs_t, hs_b, hs_t) in results { + let mut matched = true; + for (((r_b, r_t), h_b), h_t) in rs_b.iter().zip(&rs_t).zip(&hs_b).zip(&hs_t) { + if r_b == h_b { + if *h_b == BoundaryType::WordBoundary { + if matched && r_t == h_t { + n_cor += 1; + } + matched = true; + n_ref += 1; + n_sys += 1; + } + } else { + if *h_b == BoundaryType::WordBoundary { + n_sys += 1; + } else { + n_ref += 1; + } + matched = false; + } + } + if matched && rs_t.last().unwrap() == hs_t.last().unwrap() { + n_cor += 1; + } + n_sys += 1; + n_ref += 1; + } + let precision = n_cor as f64 / n_sys as f64; + let recall = n_cor as f64 / n_ref as f64; + let f1 = 2. * precision * recall / (precision + recall); + println!("Precision: {}", precision); + println!("Recall: {}", recall); + println!("F1: {}", f1); } } - let precision = n_true_positive as f64 / (n_true_positive + n_false_positive) as f64; - let recall = n_true_positive as f64 / (n_true_positive + n_false_negative) as f64; - let f1 = 2. * precision * recall / (precision + recall); - println!("Precision: {}", precision); - println!("Recall: {}", recall); - println!("F1: {}", f1); - Ok(()) } diff --git a/figures/comparison.ngp b/figures/comparison.ngp new file mode 100644 index 00000000..2d68e03b --- /dev/null +++ b/figures/comparison.ngp @@ -0,0 +1,1499 @@ +#!ngraph +#%creator: Ngraph +#%version: 6.09.03 +new axis name:fX1 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=13 + axis::inc=1 + axis::div=0 + axis::type=linear + axis::x=4600 + axis::y=2200 + axis::direction=0 + axis::baseline=true + axis::length=11300 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference= + axis::gauge=left + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=1200 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=center + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fY1 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0.5 + axis::max=2.5 + axis::inc=1 + axis::div=0 + 
axis::type=linear
[... remaining Ngraph definitions omitted: further axis, data-series, label, and viewer settings that read ./comparison.txt and draw horizontal bars of "Analysis Speed [×10^6 chars/s]" for KyTea (2020-04-03), Vaporetto (0.3.0), MeCab (2020-09-14), Kuromoji (0.9.0), Lindera (0.8.1), Sudachi (0.5.3), and sudachi.rs (0.6.2) ...]
diff --git a/figures/comparison.svg b/figures/comparison.svg
new file mode 100644
index 00000000..211ccde0
--- /dev/null
+++ b/figures/comparison.svg
@@ -0,0 +1,181 @@
[... SVG markup of the rendered comparison chart omitted; the element tags were stripped during extraction ...]
diff --git a/figures/comparison.txt b/figures/comparison.txt
new file mode 100644
index 00000000..8d480d0e
--- /dev/null
+++ b/figures/comparison.txt
@@ -0,0 +1,8 @@
+Tool Name (version),Speed [M chars/s]
+KyTea (2020-04-03),1.460792571965816
+Vaporetto (0.3.0),12.058625370294768
+MeCab (2020-09-14),4.606421319380639
+Kuromoji (0.9.0),1.4720151761081284
+Lindera (0.8.1),1.4635669561072055
+Sudachi (0.5.3),0.3197113821666647
+sudachi.rs (0.6.2),0.9940936718943929
diff --git a/manipulate_model/Cargo.toml b/manipulate_model/Cargo.toml
new file mode 100644
index 00000000..3c27d5b9
--- /dev/null
+++ b/manipulate_model/Cargo.toml
@@ -0,0 +1,11 @@
+[package]
+name = "manipulate_model"
+version = "0.1.0"
+edition = "2018"
+
+[dependencies]
+csv = "1.1" # Unlicense or MIT
+serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0
+structopt = "0.3" # MIT or Apache-2.0
+vaporetto = { path = "../vaporetto" } # MIT or Apache-2.0
+zstd = "0.9" # MIT
diff --git a/manipulate_model/src/main.rs b/manipulate_model/src/main.rs
new file mode 100644
index 00000000..0f159df3
--- /dev/null
+++ b/manipulate_model/src/main.rs
@@ -0,0 +1,88 @@
+use std::fs;
+use std::path::PathBuf;
+
+use serde::{Deserialize, Serialize};
+use structopt::StructOpt;
+use vaporetto::{Model, WordWeightRecord};
+
+#[derive(StructOpt, Debug)]
+#[structopt(
+    name = "manipulate_model",
+    about = "A program to manipulate trained models."
+)]
+struct Opt {
+    /// Input path of the model file
+    #[structopt(long)]
+    model_in: PathBuf,
+
+    /// Output path of the model file
+    #[structopt(long)]
+    model_out: Option<PathBuf>,
+
+    /// Output a dictionary contained in the model.
+    #[structopt(long)]
+    dump_dict: Option<PathBuf>,
+
+    /// Replace a dictionary if the argument is specified.
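+    ///
+    /// The CSV is expected to have the same columns that `--dump-dict` emits:
+    /// word, right, inside, left, comment.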
+ #[structopt(long)] + replace_dict: Option, +} + +#[derive(Deserialize, Serialize)] +struct WordWeightRecordFlatten { + word: String, + right: i32, + inside: i32, + left: i32, + comment: String, +} + +fn main() -> Result<(), Box> { + let opt = Opt::from_args(); + + eprintln!("Loading model file..."); + let mut f = zstd::Decoder::new(fs::File::open(opt.model_in)?)?; + let mut model = Model::read(&mut f)?; + + if let Some(path) = opt.dump_dict { + eprintln!("Saving dictionary file..."); + let file = fs::File::create(path)?; + let mut wtr = csv::Writer::from_writer(file); + for data in model.dictionary() { + wtr.serialize(WordWeightRecordFlatten { + word: data.get_word().to_string(), + right: data.get_right_weight(), + inside: data.get_inside_weight(), + left: data.get_left_weight(), + comment: data.get_comment().to_string(), + })?; + } + } + + if let Some(path) = opt.replace_dict { + eprintln!("Loading dictionary file..."); + let file = fs::File::open(path)?; + let mut rdr = csv::Reader::from_reader(file); + let mut dict = vec![]; + for result in rdr.deserialize() { + let record: WordWeightRecordFlatten = result?; + dict.push(WordWeightRecord::new( + record.word, + record.right, + record.inside, + record.left, + record.comment, + )); + } + model.replace_dictionary(dict); + } + + if let Some(path) = opt.model_out { + eprintln!("Saving model file..."); + let mut f = zstd::Encoder::new(fs::File::create(path)?, 19)?; + model.write(&mut f)?; + f.finish()?; + } + + Ok(()) +} diff --git a/model/model.zstd b/model/model.zstd deleted file mode 100644 index 8d409665..00000000 Binary files a/model/model.zstd and /dev/null differ diff --git a/predict/src/main.rs b/predict/src/main.rs index e6210f11..a6d4f9be 100644 --- a/predict/src/main.rs +++ b/predict/src/main.rs @@ -1,11 +1,12 @@ use std::fs::File; use std::io::{prelude::*, stdin}; use std::path::PathBuf; +use std::rc::Rc; use std::str::FromStr; use std::time::Instant; use structopt::StructOpt; -use vaporetto::{CharacterType, Model, Predictor, Sentence}; +use vaporetto::{errors::VaporettoError, CharacterType, Model, Predictor, Sentence}; use vaporetto_rules::{ sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter}, string_filters::KyteaFullwidthFilter, @@ -41,26 +42,80 @@ struct Opt { #[structopt(long)] model: PathBuf, + /// Predicts POS tags. + #[structopt(long)] + predict_tags: bool, + /// Do not segment some character types: {D, R, H, T, K, O, G}. /// D: Digit, R: Roman, H: Hiragana, T: Katakana, K: Kanji, O: Other, G: Grapheme cluster. #[structopt(long)] wsconst: Vec, + /// Prints scores. + #[structopt(long)] + scores: bool, + /// Do not normalize input strings before prediction. 
#[structopt(long)] no_norm: bool, } +fn print_scores(s: &Sentence) { + if !s.boundary_scores().is_empty() { + for (i, score) in s.boundary_scores().iter().enumerate() { + println!("{}:{}{} {}", i, s.chars()[i], s.chars()[i + 1], score); + } + println!(); + } +} + +fn tokenize( + predictor: &Predictor, + text: impl Into, + mut buf1: Sentence, + mut buf2: Sentence, + pre_filters: &[Box], + post_filters: &[Box], +) -> Result<(String, Sentence, Sentence), VaporettoError> { + let text = text.into(); + if pre_filters.is_empty() { + buf1.update_raw(text)?; + } else { + let text_rc = Rc::new(text); + let filt_text = Rc::try_unwrap( + pre_filters + .iter() + .fold(Rc::clone(&text_rc), |s, filter| Rc::new(filter.filter(&s))), + ) + .unwrap(); + let text = Rc::try_unwrap(text_rc).unwrap(); + buf1.update_raw(filt_text)?; + buf2.update_raw(text)?; + } + buf1 = predictor.predict_with_score(buf1); + buf1 = post_filters.iter().fold(buf1, |s, filter| filter.filter(s)); + buf1 = predictor.fill_tags(buf1); + let result = if pre_filters.is_empty() { + buf1.to_tokenized_string()? + } else { + buf2.boundaries_mut().copy_from_slice(buf1.boundaries()); + buf2.tags_mut().clone_from_slice(buf1.tags()); + buf2.to_tokenized_string()? + }; + Ok((result, buf1, buf2)) +} + fn main() -> Result<(), Box> { let opt = Opt::from_args(); - let fullwidth_filter = KyteaFullwidthFilter::new(); + let mut pre_filters: Vec> = vec![]; + if !opt.no_norm { + pre_filters.push(Box::new(KyteaFullwidthFilter)); + } let mut post_filters: Vec> = vec![]; for wsconst in &opt.wsconst { match wsconst { - WsConst::GraphemeCluster => { - post_filters.push(Box::new(ConcatGraphemeClustersFilter::new())) - } + WsConst::GraphemeCluster => post_filters.push(Box::new(ConcatGraphemeClustersFilter)), WsConst::CharType(char_type) => { post_filters.push(Box::new(KyteaWsConstFilter::new(*char_type))) } @@ -70,34 +125,34 @@ fn main() -> Result<(), Box> { eprintln!("Loading model file..."); let mut f = zstd::Decoder::new(File::open(opt.model)?)?; let model = Model::read(&mut f)?; - let predictor = Predictor::new(model); + let predictor = Predictor::new(model, opt.predict_tags)?; eprintln!("Start tokenization"); - let mut n_boundaries = 0; + let mut n_chars = 0; let start = Instant::now(); + let mut buf1 = Sentence::from_raw(" ")?; + let mut buf2 = Sentence::from_raw(" ")?; for line in stdin().lock().lines() { let line = line?; - let s = if opt.no_norm { - let s = Sentence::from_raw(line)?; - predictor.predict(s) - } else { - let norm = fullwidth_filter.filter(&line); - let mut s_orig = Sentence::from_raw(line)?; - let s = Sentence::from_raw(norm)?; - let s = predictor.predict(s); - s_orig.boundaries_mut().clone_from_slice(s.boundaries()); - s_orig - }; - let s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); - n_boundaries += s.boundaries().len(); - let toks = s.to_tokenized_string()?; - println!("{}", toks); + if line.is_empty() { + println!(); + continue; + } + let ret = tokenize(&predictor, line, buf1, buf2, &pre_filters, &post_filters)?; + let result = ret.0; + buf1 = ret.1; + buf2 = ret.2; + println!("{}", result); + if opt.scores { + print_scores(&buf1); + } + n_chars += buf1.chars().len(); } let duration = start.elapsed(); eprintln!("Elapsed: {} [sec]", duration.as_secs_f64()); eprintln!( - "Speed: {} [boundaries/sec]", - n_boundaries as f64 / duration.as_secs_f64() + "Speed: {} [chars/sec]", + n_chars as f64 / duration.as_secs_f64() ); Ok(()) diff --git a/train/src/main.rs b/train/src/main.rs index 76c6c590..6d60f579 100644 --- 
a/train/src/main.rs +++ b/train/src/main.rs @@ -4,7 +4,7 @@ use std::io::{prelude::*, stderr, BufReader}; use std::path::PathBuf; use structopt::{clap::ArgGroup, StructOpt}; -use vaporetto::{Dataset, Sentence, SolverType, Trainer}; +use vaporetto::{Sentence, SolverType, Trainer}; use vaporetto_rules::{string_filters::KyteaFullwidthFilter, StringFilter}; #[derive(StructOpt, Debug)] @@ -58,10 +58,6 @@ struct Opt { #[structopt(long, default_value = "1.0")] cost: f64, - /// Whether to use a bias value in classifier training - #[structopt(long)] - no_bias: bool, - /// The solver. {0, 1, 2, 3, 4, 5, 6, 7} (see LIBLINEAR documentation for more details) #[structopt(long, default_value = "1")] solver: SolverType, @@ -74,7 +70,7 @@ struct Opt { fn main() -> Result<(), Box> { let opt = Opt::from_args(); - let fullwidth_filter = KyteaFullwidthFilter::new(); + let fullwidth_filter = KyteaFullwidthFilter; eprintln!("Loading dataset..."); let mut train_sents = vec![]; @@ -95,6 +91,7 @@ fn main() -> Result<(), Box> { let new_line = fullwidth_filter.filter(s.to_raw_string()); let mut new_s = Sentence::from_raw(new_line)?; new_s.boundaries_mut().clone_from_slice(s.boundaries()); + new_s.tags_mut().clone_from_slice(s.tags()); new_s }; train_sents.push(s); @@ -116,7 +113,8 @@ fn main() -> Result<(), Box> { } else { let new_line = fullwidth_filter.filter(s.to_raw_string()); let mut new_s = Sentence::from_raw(new_line)?; - new_s.boundaries_mut().clone_from_slice(s.boundaries()); + new_s.boundaries_mut().copy_from_slice(s.boundaries()); + new_s.tags_mut().clone_from_slice(s.tags()); new_s }; train_sents.push(s); @@ -138,7 +136,7 @@ fn main() -> Result<(), Box> { let line = if opt.no_norm { line } else { - fullwidth_filter.filter(line) + fullwidth_filter.filter(&line) }; dictionary.insert(line); } @@ -147,21 +145,28 @@ fn main() -> Result<(), Box> { let dictionary: Vec = dictionary.into_iter().collect(); eprintln!("Extracting into features..."); - let mut dataset = Dataset::new( + let mut trainer = Trainer::new( opt.charn, opt.charw, opt.typen, opt.typew, dictionary, opt.dictn, )?; for (i, s) in train_sents.iter().enumerate() { if i % 10000 == 0 { - eprint!("# of features: {}\r", dataset.n_features()); + eprint!( + "# of features: {}, # of tag features: {}\r", + trainer.n_features(), + trainer.n_tag_features() + ); stderr().flush()?; } - dataset.push_sentence(s); + trainer.push_sentence(s)?; } - eprintln!("# of features: {}", dataset.n_features()); + eprintln!( + "# of features: {}, # of tag features: {}", + trainer.n_features(), + trainer.n_tag_features() + ); eprintln!("Start training..."); - let trainer = Trainer::new(opt.eps, opt.cost, if opt.no_bias { 0. } else { 1. 
}); - let model = trainer.train(dataset, opt.solver)?; + let model = trainer.train(opt.eps, opt.cost, opt.solver)?; eprintln!("Finish training."); let mut f = zstd::Encoder::new(File::create(opt.model)?, 19)?; diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml index 18f1baf0..82bb33df 100644 --- a/vaporetto/Cargo.toml +++ b/vaporetto/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vaporetto" -version = "0.2.0" -edition = "2018" +version = "0.3.0" +edition = "2021" authors = ["Koichi Akabe "] description = "Vaporetto: a pointwise prediction based tokenizer" license = "MIT OR Apache-2.0" @@ -10,24 +10,17 @@ repository = "https://github.com/legalforce-research/vaporetto" readme = "README.md" keywords = ["japanese", "analyzer", "tokenizer", "morphological"] categories = ["text-processing"] -autotests = false [dependencies] -anyhow = "1.0" # MIT or Apache-2.0 -bincode = "1.3.3" # MIT -daachorse = "0.2.0" # MIT or Apache-2.0 -serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0 +daachorse = "0.4.0" # MIT or Apache-2.0 -byteorder = { version = "1.4", optional = true } # Unlicense or MIT -crossbeam-channel = { version = "0.5", optional = true } # MIT or Apache-2.0 liblinear = { version = "1", optional = true } # MIT [features] -default = ["model-quantize"] -kytea = ["byteorder"] -model-quantize = [] -multithreading = ["crossbeam-channel"] +default = [] +kytea = [] train = ["liblinear"] +portable-simd = [] [package.metadata.docs.rs] all-features = true diff --git a/vaporetto/README.md b/vaporetto/README.md index 9e309661..7b40d926 100644 --- a/vaporetto/README.md +++ b/vaporetto/README.md @@ -14,14 +14,20 @@ let mut f = BufReader::new(File::open("model.raw").unwrap()); let model = Model::read(&mut f).unwrap(); let predictor = Predictor::new(model); -for line in stdin().lock().lines() { - let s = Sentence::from_raw(line.unwrap()).unwrap(); - let s = predictor.predict(s); - let toks = s.to_tokenized_string().unwrap(); - println!("{}", toks); -} +let s = Sentence::from_raw("火星猫の生態").unwrap(); +let s = predictor.predict(s); + +println!("{:?}", s.to_tokenized_vec().unwrap()); +// ["火星", "猫", "の", "生態"] ``` +## Feature flags + +* `kytea` - Enables the reader for models generated by KyTea. +* `train` - Enables the trainer. +* `portable-simd` - Uses the [portable SIMD API](https://github.com/rust-lang/portable-simd) instead + of our SIMD-conscious data layout. (Nightly Rust is required.) 
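For reference, a minimal sketch of the tag-aware API that `predict/src/main.rs` exercises earlier in this patch (`Predictor::new(model, predict_tags)`, `fill_tags`, and `to_tokenized_string`); the model path is a placeholder and error handling is condensed:

```rust
use std::fs::File;

use vaporetto::{Model, Predictor, Sentence};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // A zstd-compressed model, as produced by the `train` command.
    let mut f = zstd::Decoder::new(File::open("path/to/model.zst")?)?;
    let model = Model::read(&mut f)?;

    // The second argument enables POS-tag prediction.
    let predictor = Predictor::new(model, true)?;

    let s = Sentence::from_raw("火星猫の生態")?;
    let s = predictor.predict(s);
    let s = predictor.fill_tags(s);

    println!("{}", s.to_tokenized_string()?);
    Ok(())
}
```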
+ ## License Licensed under either of diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs new file mode 100644 index 00000000..e5bc3f92 --- /dev/null +++ b/vaporetto/src/char_scorer.rs @@ -0,0 +1,352 @@ +use std::iter; +use std::sync::Arc; + +use daachorse::DoubleArrayAhoCorasick; + +use crate::dict_model::DictModel; +use crate::errors::{Result, VaporettoError}; +use crate::ngram_model::NgramModel; +use crate::sentence::{Sentence, TagRangeScore, TagRangeScores, TagScores}; +use crate::utils::{self, AddWeight, MergableWeight, WeightMerger}; + +#[cfg(feature = "portable-simd")] +use std::simd::i32x8; + +pub const SIMD_SIZE: usize = 8; +#[cfg(feature = "portable-simd")] +type I32Vec = i32x8; + +#[derive(Clone)] +struct PositionalWeight { + pub offset: i32, + pub weight: W, +} + +type NaivePositionalWeight = PositionalWeight>; + +impl NaivePositionalWeight { + fn new(offset: i32, weight: Vec) -> Self { + Self { offset, weight } + } +} + +impl MergableWeight for NaivePositionalWeight { + fn from_two_weights(weight1: &Self, weight2: &Self, n_classes: usize) -> Self { + debug_assert!(n_classes != 0); + let (weight1, weight2) = if weight1.offset > weight2.offset { + (weight2, weight1) + } else { + (weight1, weight2) + }; + let shift = (weight2.offset - weight1.offset) as usize * n_classes; + let mut weight = vec![0; weight1.weight.len().max(shift + weight2.weight.len())]; + weight[..weight1.weight.len()].copy_from_slice(&weight1.weight); + for (r, w2) in weight[shift..].iter_mut().zip(&weight2.weight) { + *r += w2; + } + Self { + offset: weight1.offset, + weight, + } + } +} + +#[derive(Clone)] +enum WeightVector { + Array(Vec), + + #[cfg(not(feature = "portable-simd"))] + Simd([i32; SIMD_SIZE]), + #[cfg(feature = "portable-simd")] + Simd(I32Vec), +} + +impl WeightVector { + pub fn new(weight: Vec) -> Self { + if weight.len() <= SIMD_SIZE { + let mut s = [0i32; SIMD_SIZE]; + s[..weight.len()].copy_from_slice(weight.as_slice()); + #[cfg(not(feature = "portable-simd"))] + { + Self::Simd(s) + } + #[cfg(feature = "portable-simd")] + { + Self::Simd(I32Vec::from_array(s)) + } + } else { + Self::Array(weight) + } + } +} + +impl AddWeight for WeightVector { + fn add_weight(&self, ys: &mut [i32], offset: isize) { + match self { + WeightVector::Array(weight) => { + weight.add_weight(ys, offset); + } + WeightVector::Simd(weight) => { + let ys_slice = &mut ys[offset as usize..offset as usize + SIMD_SIZE]; + #[cfg(feature = "portable-simd")] + { + let mut target = I32Vec::from_slice(ys_slice); + target += weight; + ys_slice.copy_from_slice(target.as_array()); + } + #[cfg(not(feature = "portable-simd"))] + for (y, w) in ys_slice.iter_mut().zip(weight) { + *y += w; + } + } + } + } +} + +pub struct WeightSet +where + W: Clone, +{ + boundary: Option>, + tag_left: Option>>, + tag_right: Option>>, + tag_self: Option, +} + +type NaiveWeightSet = WeightSet>; + +impl NaiveWeightSet { + fn boundary_weight(offset: i32, weight: Vec) -> Self { + Self { + boundary: Some(PositionalWeight::new(offset, weight)), + tag_left: None, + tag_right: None, + tag_self: None, + } + } + + fn tag_left_weight(offset: i32, weight: Vec) -> Self { + Self { + boundary: None, + tag_left: Some(PositionalWeight::new(offset, weight)), + tag_right: None, + tag_self: None, + } + } + + fn tag_right_weight(offset: i32, weight: Vec) -> Self { + Self { + boundary: None, + tag_left: None, + tag_right: Some(PositionalWeight::new(offset, weight)), + tag_self: None, + } + } + + fn tag_self_weight(start_rel_position: i32, weight: Vec) -> 
Self { + Self { + boundary: None, + tag_left: None, + tag_right: None, + tag_self: Some(Arc::new(vec![TagRangeScore::new( + start_rel_position, + weight, + )])), + } + } +} + +impl MergableWeight for NaiveWeightSet { + fn from_two_weights(weight1: &Self, weight2: &Self, n_classes: usize) -> Self { + Self { + boundary: utils::xor_or_zip_with(&weight1.boundary, &weight2.boundary, |w1, w2| { + PositionalWeight::from_two_weights(w1, w2, 1) + }), + tag_left: utils::xor_or_zip_with(&weight1.tag_left, &weight2.tag_left, |w1, w2| { + PositionalWeight::from_two_weights(w1, w2, n_classes) + }), + tag_right: utils::xor_or_zip_with(&weight1.tag_right, &weight2.tag_right, |w1, w2| { + PositionalWeight::from_two_weights(w1, w2, n_classes) + }), + tag_self: utils::xor_or_zip_with(&weight1.tag_self, &weight2.tag_self, |w1, w2| { + let mut w = w1.to_vec(); + w.append(&mut w2.to_vec()); + Arc::new(w) + }), + } + } +} + +pub struct CharScorer { + pma: DoubleArrayAhoCorasick, + weights: Vec>, +} + +impl CharScorer { + pub fn new(model: NgramModel, window_size: usize, dict: DictModel) -> Result { + let mut weight_merger = WeightMerger::new(1); + + for d in model.data { + let weight = PositionalWeight::new(-(window_size as i32) - 1, d.weights); + weight_merger.add(&d.ngram, weight); + } + for d in dict.dict { + let word_len = d.word.chars().count(); + let mut weight = Vec::with_capacity(word_len + 1); + weight.push(d.weights.right); + weight.resize(word_len, d.weights.inside); + weight.push(d.weights.left); + let weight = PositionalWeight::new(-(word_len as i32) - 1, weight); + weight_merger.add(&d.word, weight); + } + + let mut ngrams = vec![]; + let mut weights = vec![]; + for (ngram, data) in weight_merger.merge() { + ngrams.push(ngram); + let PositionalWeight { offset, weight } = data; + weights.push(PositionalWeight { + offset, + weight: WeightVector::new(weight), + }); + } + let pma = DoubleArrayAhoCorasick::new(ngrams) + .map_err(|_| VaporettoError::invalid_model("invalid character n-grams"))?; + Ok(Self { pma, weights }) + } + + pub fn add_scores(&self, sentence: &Sentence, padding: usize, ys: &mut [i32]) { + // If the following assertion fails, Vaporetto has a bug. + assert_eq!(sentence.str_to_char_pos.len(), sentence.text.len() + 1); + + for m in self.pma.find_overlapping_no_suffix_iter(&sentence.text) { + // This was checked outside of the iteration. + let m_end = unsafe { *sentence.str_to_char_pos.get_unchecked(m.end()) }; + // Both the weights and the PMA always have the same number of items. + // Therefore, the following code is safe. 
+ let pos_weights = unsafe { self.weights.get_unchecked(m.value()) }; + + let offset = padding as isize + m_end as isize + pos_weights.offset as isize; + pos_weights.weight.add_weight(ys, offset); + } + } +} + +pub struct CharScorerWithTags { + pma: DoubleArrayAhoCorasick, + weights: Vec>, + n_tags: usize, +} + +impl CharScorerWithTags { + pub fn new( + model: NgramModel, + window_size: usize, + dict: DictModel, + n_tags: usize, + tag_left_model: NgramModel, + tag_right_model: NgramModel, + tag_self_model: NgramModel, + ) -> Result { + let mut weight_merger = WeightMerger::new(n_tags); + + for d in model.data { + let weight = WeightSet::boundary_weight(-(window_size as i32), d.weights); + weight_merger.add(&d.ngram, weight); + } + for d in dict.dict { + let word_len = d.word.chars().count(); + let mut weight = Vec::with_capacity(word_len + 1); + weight.push(d.weights.right); + weight.resize(word_len, d.weights.inside); + weight.push(d.weights.left); + let weight = WeightSet::boundary_weight(-(word_len as i32), weight); + weight_merger.add(&d.word, weight); + } + for d in tag_left_model.data { + let weight = + WeightSet::tag_left_weight(-(d.ngram.chars().count() as i32) + 1, d.weights); + weight_merger.add(&d.ngram, weight); + } + for d in tag_right_model.data { + let weight = WeightSet::tag_right_weight(-(window_size as i32) - 1, d.weights); + weight_merger.add(&d.ngram, weight); + } + for d in tag_self_model.data { + let weight = WeightSet::tag_self_weight(-(d.ngram.chars().count() as i32), d.weights); + weight_merger.add(&d.ngram, weight); + } + + let mut ngrams = vec![]; + let mut weights = vec![]; + for (ngram, data) in weight_merger.merge() { + ngrams.push(ngram); + let WeightSet { + boundary, + tag_left, + tag_right, + tag_self, + } = data; + weights.push(WeightSet { + boundary: boundary.map(|PositionalWeight { offset, weight }| PositionalWeight { + offset, + weight: WeightVector::new(weight), + }), + tag_left, + tag_right, + tag_self, + }); + } + let pma = DoubleArrayAhoCorasick::new(ngrams) + .map_err(|_| VaporettoError::invalid_model("invalid character n-grams"))?; + Ok(Self { + pma, + weights, + n_tags, + }) + } + + pub fn add_scores( + &self, + sentence: &Sentence, + padding: usize, + ys: &mut [i32], + tag_ys: &mut TagScores, + ) { + for m in self.pma.find_overlapping_no_suffix_iter_from_iter( + iter::once(0) + .chain(sentence.text.as_bytes().iter().cloned()) + .chain(iter::once(0)), + ) { + let m_end = sentence + .str_to_char_pos + .get(m.end() - 1) + .copied() + .unwrap_or(sentence.chars.len() + 1); + + // Both the weights and the PMA always have the same number of items. + // Therefore, the following code is safe. 
+ let weight_set = unsafe { self.weights.get_unchecked(m.value()) }; + + if let Some(pos_weights) = weight_set.boundary.as_ref() { + let offset = padding as isize + m_end as isize + pos_weights.offset as isize - 1; + pos_weights.weight.add_weight(ys, offset); + } + if let Some(pos_weights) = weight_set.tag_left.as_ref() { + let offset = (m_end as isize + pos_weights.offset as isize) * self.n_tags as isize; + pos_weights + .weight + .add_weight(&mut tag_ys.left_scores, offset); + } + if let Some(pos_weights) = weight_set.tag_right.as_ref() { + let offset = (m_end as isize + pos_weights.offset as isize) * self.n_tags as isize; + pos_weights + .weight + .add_weight(&mut tag_ys.right_scores, offset); + } + if let Some(weight) = weight_set.tag_self.as_ref() { + tag_ys.self_scores[m_end - 1].replace(Arc::clone(weight)); + } + } + } +} diff --git a/vaporetto/src/dict_model.rs b/vaporetto/src/dict_model.rs new file mode 100644 index 00000000..f1a678c0 --- /dev/null +++ b/vaporetto/src/dict_model.rs @@ -0,0 +1,167 @@ +use std::io::{Read, Write}; +use std::mem; + +use crate::errors::Result; +use crate::utils; + +#[derive(Clone, Copy, Default)] +pub struct DictWeight { + pub right: i32, + pub inside: i32, + pub left: i32, +} + +impl DictWeight { + pub fn serialize(&self, mut wtr: W) -> Result + where + W: Write, + { + utils::write_i32(&mut wtr, self.right)?; + utils::write_i32(&mut wtr, self.inside)?; + utils::write_i32(&mut wtr, self.left)?; + Ok(mem::size_of::() * 3) + } + + pub fn deserialize(mut rdr: R) -> Result + where + R: Read, + { + Ok(Self { + right: utils::read_i32(&mut rdr)?, + inside: utils::read_i32(&mut rdr)?, + left: utils::read_i32(&mut rdr)?, + }) + } +} + +/// Record of weights for each word. +#[derive(Clone)] +pub struct WordWeightRecord { + pub(crate) word: String, + pub(crate) weights: DictWeight, + pub(crate) comment: String, +} + +impl WordWeightRecord { + pub fn serialize(&self, mut wtr: W) -> Result + where + W: Write, + { + let word_size = self.word.len(); + let comment_size = self.comment.len(); + utils::write_u32(&mut wtr, u32::try_from(word_size).unwrap())?; + utils::write_u32(&mut wtr, u32::try_from(comment_size).unwrap())?; + wtr.write_all(self.word.as_bytes())?; + wtr.write_all(self.comment.as_bytes())?; + let weights_size = self.weights.serialize(&mut wtr)?; + Ok(mem::size_of::() * 2 + word_size + weights_size + comment_size) + } + + pub fn deserialize(mut rdr: R) -> Result + where + R: Read, + { + let word_size = utils::read_u32(&mut rdr)?; + let comment_size = utils::read_u32(&mut rdr)?; + let mut word_bytes = vec![0; word_size.try_into().unwrap()]; + rdr.read_exact(&mut word_bytes)?; + let mut comment_bytes = vec![0; comment_size.try_into().unwrap()]; + rdr.read_exact(&mut comment_bytes)?; + Ok(Self { + word: String::from_utf8(word_bytes)?, + weights: DictWeight::deserialize(&mut rdr)?, + comment: String::from_utf8(comment_bytes)?, + }) + } +} + +impl WordWeightRecord { + /// Creates a new word weight record. + /// + /// # Arguments + /// + /// * `word` - A word. + /// * `right` - A weight of the boundary when the word is found at right. + /// * `inside` - A weight of the boundary when the word is overlapped on the boundary. + /// * `left` - A weight of the boundary when the word is found at left. + /// * `comment` - A comment that does not affect the behaviour. + /// + /// # Returns + /// + /// A new record. 
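+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch; the word and weights below are made up:
+    ///
+    /// ```
+    /// use vaporetto::WordWeightRecord;
+    ///
+    /// let record = WordWeightRecord::new("火星猫".to_string(), 5, -1, 5, String::new());
+    ///
+    /// assert_eq!("火星猫", record.get_word());
+    /// assert_eq!(5, record.get_right_weight());
+    /// assert_eq!(-1, record.get_inside_weight());
+    /// ```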
+ pub const fn new(word: String, right: i32, inside: i32, left: i32, comment: String) -> Self { + Self { + word, + weights: DictWeight { + right, + inside, + left, + }, + comment, + } + } + + /// Gets a reference to the word. + pub fn get_word(&self) -> &str { + &self.word + } + + /// Gets a `right` weight. + pub const fn get_right_weight(&self) -> i32 { + self.weights.right + } + + /// Gets a `inside` weight. + pub const fn get_inside_weight(&self) -> i32 { + self.weights.inside + } + + /// Gets a `left` weight. + pub const fn get_left_weight(&self) -> i32 { + self.weights.left + } + + /// Gets a reference to the comment. + pub fn get_comment(&self) -> &str { + &self.comment + } +} + +pub struct DictModel { + pub(crate) dict: Vec, +} + +impl DictModel { + pub fn new(dict: Vec) -> Self { + Self { dict } + } + + pub fn dictionary(&self) -> &[WordWeightRecord] { + &self.dict + } + + pub fn serialize(&self, mut wtr: W) -> Result + where + W: Write, + { + let dict_size = self.dict.len(); + utils::write_u32(&mut wtr, dict_size.try_into().unwrap())?; + let mut total_size = mem::size_of::(); + for entry in &self.dict { + total_size += entry.serialize(&mut wtr)?; + } + Ok(total_size) + } + + pub fn deserialize(mut rdr: R) -> Result + where + R: Read, + { + let dict_size = utils::read_u32(&mut rdr)?; + let mut dict = Vec::with_capacity(dict_size.try_into().unwrap()); + for _ in 0..dict_size { + dict.push(WordWeightRecord::deserialize(&mut rdr)?); + } + Ok(Self { dict }) + } +} diff --git a/vaporetto/src/errors.rs b/vaporetto/src/errors.rs new file mode 100644 index 00000000..5597a8ed --- /dev/null +++ b/vaporetto/src/errors.rs @@ -0,0 +1,115 @@ +//! Definition of errors. + +use std::error::Error; +use std::fmt; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum VaporettoError { + InvalidModel(InvalidModelError), + InvalidSentence(InvalidSentenceError), + InvalidArgument(InvalidArgumentError), + IOError(std::io::Error), + UTF8Error(std::string::FromUtf8Error), +} + +impl VaporettoError { + pub(crate) fn invalid_model(msg: S) -> Self + where + S: Into, + { + Self::InvalidModel(InvalidModelError { msg: msg.into() }) + } + + pub(crate) fn invalid_sentence(msg: S) -> Self + where + S: Into, + { + Self::InvalidSentence(InvalidSentenceError { msg: msg.into() }) + } + + pub(crate) fn invalid_argument(arg: &'static str, msg: S) -> Self + where + S: Into, + { + Self::InvalidArgument(InvalidArgumentError { + arg, + msg: msg.into(), + }) + } +} + +impl fmt::Display for VaporettoError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::InvalidModel(e) => e.fmt(f), + Self::InvalidSentence(e) => e.fmt(f), + Self::InvalidArgument(e) => e.fmt(f), + Self::IOError(e) => e.fmt(f), + Self::UTF8Error(e) => e.fmt(f), + } + } +} + +impl Error for VaporettoError {} + +/// Error used when the model is invalid. +#[derive(Debug)] +pub struct InvalidModelError { + /// Error message. + pub(crate) msg: String, +} + +impl fmt::Display for InvalidModelError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "InvalidModelError: {}", self.msg) + } +} + +impl Error for InvalidModelError {} + +/// Error used when the sentence is invalid. +#[derive(Debug)] +pub struct InvalidSentenceError { + /// Error message. 
+ pub(crate) msg: String, +} + +impl fmt::Display for InvalidSentenceError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "InvalidSentenceError: {}", self.msg) + } +} + +impl Error for InvalidSentenceError {} + +/// Error used when the argument is invalid. +#[derive(Debug)] +pub struct InvalidArgumentError { + /// Name of the argument. + pub(crate) arg: &'static str, + + /// Error message. + pub(crate) msg: String, +} + +impl fmt::Display for InvalidArgumentError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "InvalidArgumentError: {}: {}", self.arg, self.msg) + } +} + +impl Error for InvalidArgumentError {} + +impl From for VaporettoError { + fn from(error: std::io::Error) -> Self { + Self::IOError(error) + } +} + +impl From for VaporettoError { + fn from(error: std::string::FromUtf8Error) -> Self { + Self::UTF8Error(error) + } +} diff --git a/vaporetto/src/feature.rs b/vaporetto/src/feature.rs index 31a12f9a..f1f80bb1 100644 --- a/vaporetto/src/feature.rs +++ b/vaporetto/src/feature.rs @@ -1,383 +1,381 @@ -use crate::sentence::{BoundaryType, Sentence}; +use std::hash::Hash; +use std::sync::Arc; -use anyhow::{anyhow, Result}; use daachorse::DoubleArrayAhoCorasick; -#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)] -pub enum FeatureContent<'a> { - CharacterNgram(&'a str), - CharacterTypeNgram(&'a [u8]), - DictionaryWord(usize), +use crate::errors::{Result, VaporettoError}; +use crate::sentence::BoundaryType; +use crate::sentence::Sentence; + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct StringNgramFeature<'a> { + pub(crate) rel_position: isize, + pub(crate) ngram: &'a str, } -#[derive(Debug, PartialEq)] -pub struct FeatureSpan<'a> { - start: usize, - end: usize, - feature: FeatureContent<'a>, +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct BytesNgramFeature<'a> { + pub(crate) rel_position: isize, + pub(crate) ngram: &'a [u8], } #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct Feature<'a> { - pub(crate) rel_position: usize, - pub(crate) feature: FeatureContent<'a>, +pub enum DictionaryWordPosition { + Right, + Left, + Inside, +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct DictionaryWordFeature { + pub(crate) position: DictionaryWordPosition, + pub(crate) length: usize, +} + +#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)] +pub enum BoundaryFeature<'a> { + CharacterNgram(StringNgramFeature<'a>), + CharacterTypeNgram(BytesNgramFeature<'a>), + DictionaryWord(DictionaryWordFeature), +} + +impl<'a> BoundaryFeature<'a> { + pub const fn char_ngram(rel_position: isize, ngram: &'a str) -> Self { + Self::CharacterNgram(StringNgramFeature { + rel_position, + ngram, + }) + } + + pub const fn type_ngram(rel_position: isize, ngram: &'a [u8]) -> Self { + Self::CharacterTypeNgram(BytesNgramFeature { + rel_position, + ngram, + }) + } + + pub const fn dict_word(position: DictionaryWordPosition, length: usize) -> Self { + Self::DictionaryWord(DictionaryWordFeature { position, length }) + } } #[derive(Debug, PartialEq)] -pub struct Example<'a> { - pub features: Vec>, +pub struct BoundaryExample<'a> { + pub features: Vec>, pub label: BoundaryType, } -pub struct FeatureExtractor { +pub struct BoundaryExampleGenerator { char_ngram_size: usize, type_ngram_size: usize, - dict_ac: DoubleArrayAhoCorasick, - dict_word_size: Vec, + char_window_size: usize, + type_window_size: usize, + dict_ac: Option, + dict_max_word_size: usize, } -impl FeatureExtractor { - pub fn new( +impl 
BoundaryExampleGenerator { + pub fn new( char_ngram_size: usize, type_ngram_size: usize, - dictionary: D, - dict_word_max_size: usize, + char_window_size: usize, + type_window_size: usize, + dict: Option, + dict_max_word_size: usize, ) -> Result where - D: AsRef<[P]>, - P: AsRef<[u8]> + AsRef, + I: IntoIterator, + P: AsRef<[u8]>, { - let dictionary = dictionary.as_ref(); - let mut dict_word_size = Vec::with_capacity(dictionary.len()); - for word in dictionary { - let size = std::cmp::min( - AsRef::::as_ref(word).chars().count(), - dict_word_max_size, - ); - if size == 0 { - return Err(anyhow!("`dictionary` contains an empty string")); - } - dict_word_size.push(size); - } + let dict_ac = if let Some(dict) = dict { + Some( + DoubleArrayAhoCorasick::new(dict) + .map_err(|e| VaporettoError::invalid_argument("dict", format!("{:?}", e)))?, + ) + } else { + None + }; Ok(Self { char_ngram_size, type_ngram_size, - dict_ac: DoubleArrayAhoCorasick::new(dictionary).unwrap(), - dict_word_size, + char_window_size, + type_window_size, + dict_ac, + dict_max_word_size, }) } - pub fn extract<'a>(&self, sentence: &'a Sentence) -> Vec> { - let mut features = vec![]; - for n in 0..self.char_ngram_size as isize { - for i in 0..sentence.char_type.len() as isize - n { - let start = i as usize; - let end = (i + n + 1) as usize; - let feature = FeatureContent::CharacterNgram(sentence.char_substring(start, end)); - features.push(FeatureSpan { - start, - end, - feature, - }) + pub fn generate<'a>(&self, s: &'a Sentence) -> Vec> { + let mut result = vec![]; + for (i, &label) in s.boundaries().iter().enumerate() { + let mut features = vec![]; + for n in 1..self.char_ngram_size + 1 { + let begin = (i + 1).saturating_sub(self.char_window_size); + let end = (i + 1 + self.char_window_size) + .min(s.chars.len()) + .saturating_sub(n - 1); + for pos in begin..end { + let rel_position = pos as isize - i as isize - 1; + let ngram = s.char_substring(pos, pos + n); + features.push(BoundaryFeature::char_ngram(rel_position, ngram)); + } } - } - for n in 0..self.type_ngram_size as isize { - for i in 0..sentence.char_type.len() as isize - n { - let start = i as usize; - let end = (i + n + 1) as usize; - let feature = - FeatureContent::CharacterTypeNgram(sentence.type_substring(start, end)); - features.push(FeatureSpan { - start, - end, - feature, - }); + for n in 1..self.type_ngram_size + 1 { + let begin = (i + 1).saturating_sub(self.type_window_size); + let end = (i + 1 + self.type_window_size) + .min(s.chars.len()) + .saturating_sub(n - 1); + for pos in begin..end { + let rel_position = pos as isize - i as isize - 1; + let ngram = &s.char_types()[pos..pos + n]; + features.push(BoundaryFeature::type_ngram(rel_position, ngram)); + } } + result.push(BoundaryExample { features, label }) } - for m in self.dict_ac.find_overlapping_iter(&sentence.text) { - let start = sentence.str_to_char_pos[m.start()]; - let end = sentence.str_to_char_pos[m.end()]; - let feature = FeatureContent::DictionaryWord(self.dict_word_size[m.pattern()]); - features.push(FeatureSpan { - start, - end, - feature, - }); + if let Some(dict_ac) = self.dict_ac.as_ref() { + for m in dict_ac.find_overlapping_iter(&s.text) { + let m_start = s.str_to_char_pos[m.start()]; + let m_end = s.str_to_char_pos[m.end()]; + let length = (m_end - m_start).min(self.dict_max_word_size); + if m_start != 0 { + result[m_start - 1] + .features + .push(BoundaryFeature::dict_word( + DictionaryWordPosition::Right, + length, + )); + } + for example in &mut result[m_start..m_end - 1] { + 
example.features.push(BoundaryFeature::dict_word( + DictionaryWordPosition::Inside, + length, + )); + } + if m_end != s.chars().len() { + result[m_end - 1].features.push(BoundaryFeature::dict_word( + DictionaryWordPosition::Left, + length, + )); + } + } } - features + result + .into_iter() + .filter(|example| example.label != BoundaryType::Unknown) + .collect() } } -pub struct ExampleGenerator { +#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum TagFeature<'a> { + LeftCharacterNgram(StringNgramFeature<'a>), + LeftCharacterNgramBos(StringNgramFeature<'a>), + RightCharacterNgram(StringNgramFeature<'a>), + RightCharacterNgramEos(StringNgramFeature<'a>), + Character(&'a str), +} + +impl<'a> TagFeature<'a> { + pub const fn left_char_ngram(rel_position: isize, ngram: &'a str) -> Self { + Self::LeftCharacterNgram(StringNgramFeature { + rel_position, + ngram, + }) + } + + pub const fn left_char_ngram_bos(rel_position: isize, ngram: &'a str) -> Self { + Self::LeftCharacterNgramBos(StringNgramFeature { + rel_position, + ngram, + }) + } + + pub const fn right_char_ngram(rel_position: isize, ngram: &'a str) -> Self { + Self::RightCharacterNgram(StringNgramFeature { + rel_position, + ngram, + }) + } + + pub const fn right_char_ngram_eos(rel_position: isize, ngram: &'a str) -> Self { + Self::RightCharacterNgramEos(StringNgramFeature { + rel_position, + ngram, + }) + } + + pub const fn chars(chars: &'a str) -> Self { + Self::Character(chars) + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct TagExample<'a> { + pub features: Vec>, + pub tag: Arc, +} + +pub struct TagExampleGenerator { + char_ngram_size: usize, char_window_size: usize, - type_window_size: usize, } -impl ExampleGenerator { - pub const fn new(char_window_size: usize, type_window_size: usize) -> Self { +impl TagExampleGenerator { + pub const fn new(char_ngram_size: usize, char_window_size: usize) -> Self { Self { + char_ngram_size, char_window_size, - type_window_size, } } - pub fn generate<'a>( - &self, - sentence: &'a Sentence, - feature_spans: impl Into>>, - include_unknown: bool, - ) -> Vec> { - let mut examples: Vec = sentence - .boundaries + pub fn generate<'a>(&self, sentence: &'a Sentence) -> Result>> { + let mut result = vec![]; + let mut features = vec![]; + for start in (sentence.chars.len() + 1).saturating_sub(self.char_ngram_size) + ..sentence.chars.len() + 1 + { + features.push(TagFeature::right_char_ngram_eos( + 1, + sentence.char_substring(start, sentence.chars.len()), + )); + } + let mut current_tag: Option> = sentence + .tags + .last() + .and_then(|x| x.as_ref()) + .map(Arc::clone); + let mut tag_right_pos = sentence.chars.len(); + for (i, (t, b)) in sentence + .tags .iter() - .map(|&label| Example { - features: vec![], - label, - }) - .collect(); - for span in feature_spans.into() { - match span.feature { - FeatureContent::CharacterNgram(_) => { - let start = - std::cmp::max(span.end - 1, self.char_window_size) - self.char_window_size; - let end = std::cmp::min( - span.start + self.char_window_size, - sentence.boundaries.len(), - ); - for (i, example) in examples.iter_mut().enumerate().take(end).skip(start) { - if include_unknown || example.label != BoundaryType::Unknown { - example.features.push(Feature { - rel_position: self.char_window_size + i + 1 - span.end, - feature: span.feature, - }); + .zip(sentence.boundaries()) + .enumerate() + .rev() + { + match b { + BoundaryType::WordBoundary => { + if let Some(tag) = current_tag.take() { + if i + 2 <= self.char_window_size { + 
let rel_position = -(i as isize) - 2; + for end in 0..sentence.chars.len().min(self.char_ngram_size) { + features.push(TagFeature::left_char_ngram_bos( + rel_position, + sentence.char_substring(0, end), + )); + } } - } - } - FeatureContent::CharacterTypeNgram(_) => { - let start = - std::cmp::max(span.end - 1, self.type_window_size) - self.type_window_size; - let end = std::cmp::min( - span.start + self.type_window_size, - sentence.boundaries.len(), - ); - for (i, example) in examples.iter_mut().enumerate().take(end).skip(start) { - if include_unknown || example.label != BoundaryType::Unknown { - example.features.push(Feature { - rel_position: self.type_window_size + i + 1 - span.end, - feature: span.feature, - }); + for j in (i + 1).saturating_sub(self.char_window_size)..i + 1 { + let rel_position = j as isize - i as isize - 1; + for end in j + 1..sentence.chars.len().min(j + self.char_ngram_size) + 1 + { + features.push(TagFeature::left_char_ngram( + rel_position, + sentence.char_substring(j, end), + )); + } } + features.push(TagFeature::chars( + sentence.char_substring(i + 1, tag_right_pos), + )); + result.push(TagExample { features, tag }); + features = vec![]; } - } - FeatureContent::DictionaryWord(_) => { - if span.start >= 1 { - let example = &mut examples[span.start - 1]; - if include_unknown || example.label != BoundaryType::Unknown { - example.features.push(Feature { - rel_position: 0, - feature: span.feature, - }); + if let Some(tag) = t.as_ref() { + current_tag.replace(Arc::clone(tag)); + tag_right_pos = i + 1; + for j in + (i + 2)..(i + 2 + self.char_window_size).min(sentence.chars.len() + 1) + { + let rel_position = j as isize - i as isize - 1; + for start in j.saturating_sub(self.char_ngram_size)..j { + features.push(TagFeature::right_char_ngram( + rel_position, + sentence.char_substring(start, j), + )); + } } - } - for example in &mut examples[span.start..span.end - 1] { - if include_unknown || example.label != BoundaryType::Unknown { - example.features.push(Feature { - rel_position: 1, - feature: span.feature, - }); + if i + self.char_window_size >= sentence.chars.len() { + let rel_position = sentence.chars.len() as isize - i as isize; + for start in (sentence.chars.len() + 1) + .saturating_sub(self.char_ngram_size) + ..sentence.chars.len() + 1 + { + features.push(TagFeature::right_char_ngram_eos( + rel_position, + sentence.char_substring(start, sentence.chars.len()), + )); + } } } - if span.end <= examples.len() { - let example = &mut examples[span.end - 1]; - if include_unknown || example.label != BoundaryType::Unknown { - example.features.push(Feature { - rel_position: 2, - feature: span.feature, - }); - } + } + BoundaryType::NotWordBoundary => (), + BoundaryType::Unknown => { + if current_tag.is_some() { + return Err(VaporettoError::invalid_argument("sentence", "")); } } } } - if include_unknown { - examples - } else { - examples - .into_iter() - .filter(|example| example.label != BoundaryType::Unknown) - .collect() + if let Some(tag) = current_tag.take() { + for end in 0..sentence.chars.len().min(self.char_ngram_size) { + features.push(TagFeature::left_char_ngram_bos( + -1, + sentence.char_substring(0, end), + )); + } + features.push(TagFeature::chars(sentence.char_substring(0, tag_right_pos))); + result.push(TagExample { features, tag }); } + Ok(result) } } #[cfg(test)] mod tests { use super::*; - use crate::sentence::CharacterType::*; + use BoundaryFeature::*; use BoundaryType::*; - use FeatureContent::*; - - #[test] - fn 
test_feature_extractor_new_empty_dict_string() { - let dict = ["東京特許許可局", "", "猫"]; - let fe = FeatureExtractor::new(3, 2, dict, 4); - - assert!(fe.is_err()); - assert_eq!( - "`dictionary` contains an empty string", - &fe.err().unwrap().to_string() - ); - } - - #[test] - fn test_feature_extractor_new_empty_dict() { - let dict: &[String] = &[]; - let fe = FeatureExtractor::new(3, 2, dict, 4).unwrap(); - - assert_eq!(3, fe.char_ngram_size); - assert_eq!(2, fe.type_ngram_size); - assert_eq!(Vec::::new(), fe.dict_word_size); - } - - #[test] - fn test_feature_extractor_new() { - let dict = ["東京特許許可局", "火星猫", "猫"]; - let fe = FeatureExtractor::new(3, 2, dict, 4).unwrap(); - - assert_eq!(3, fe.char_ngram_size); - assert_eq!(2, fe.type_ngram_size); - assert_eq!(vec![4, 3, 1], fe.dict_word_size); - } - - #[test] - fn test_feature_extractor_extract_one() { - let dict = ["東京特許許可局", "火星猫", "猫"]; - let fe = FeatureExtractor::new(3, 2, dict, 4).unwrap(); - let s = Sentence::from_raw("A").unwrap(); - let feature_spans = fe.extract(&s); - - let expected = vec![ - FeatureSpan { - start: 0, - end: 1, - feature: CharacterNgram("A"), - }, - FeatureSpan { - start: 0, - end: 1, - feature: CharacterTypeNgram(&ct2u8![Roman]), - }, - ]; - assert_eq!(expected, feature_spans); - } - - #[test] - fn test_feature_extractor_extract() { - let dict = ["東京特許許可局", "火星猫", "猫"]; - let fe = FeatureExtractor::new(3, 2, dict, 2).unwrap(); - let s = Sentence::from_raw("Ariaは火星猫だ").unwrap(); - let feature_spans = fe.extract(&s); - - #[rustfmt::skip] - let expected = vec![ - FeatureSpan { start: 0, end: 1, feature: CharacterNgram("A") }, - FeatureSpan { start: 1, end: 2, feature: CharacterNgram("r") }, - FeatureSpan { start: 2, end: 3, feature: CharacterNgram("i") }, - FeatureSpan { start: 3, end: 4, feature: CharacterNgram("a") }, - FeatureSpan { start: 4, end: 5, feature: CharacterNgram("は") }, - FeatureSpan { start: 5, end: 6, feature: CharacterNgram("火") }, - FeatureSpan { start: 6, end: 7, feature: CharacterNgram("星") }, - FeatureSpan { start: 7, end: 8, feature: CharacterNgram("猫") }, - FeatureSpan { start: 8, end: 9, feature: CharacterNgram("だ") }, - FeatureSpan { start: 0, end: 2, feature: CharacterNgram("Ar") }, - FeatureSpan { start: 1, end: 3, feature: CharacterNgram("ri") }, - FeatureSpan { start: 2, end: 4, feature: CharacterNgram("ia") }, - FeatureSpan { start: 3, end: 5, feature: CharacterNgram("aは") }, - FeatureSpan { start: 4, end: 6, feature: CharacterNgram("は火") }, - FeatureSpan { start: 5, end: 7, feature: CharacterNgram("火星") }, - FeatureSpan { start: 6, end: 8, feature: CharacterNgram("星猫") }, - FeatureSpan { start: 7, end: 9, feature: CharacterNgram("猫だ") }, - FeatureSpan { start: 0, end: 3, feature: CharacterNgram("Ari") }, - FeatureSpan { start: 1, end: 4, feature: CharacterNgram("ria") }, - FeatureSpan { start: 2, end: 5, feature: CharacterNgram("iaは") }, - FeatureSpan { start: 3, end: 6, feature: CharacterNgram("aは火") }, - FeatureSpan { start: 4, end: 7, feature: CharacterNgram("は火星") }, - FeatureSpan { start: 5, end: 8, feature: CharacterNgram("火星猫") }, - FeatureSpan { start: 6, end: 9, feature: CharacterNgram("星猫だ") }, - FeatureSpan { start: 0, end: 1, feature: CharacterTypeNgram(&ct2u8![Roman]) }, - FeatureSpan { start: 1, end: 2, feature: CharacterTypeNgram(&ct2u8![Roman]) }, - FeatureSpan { start: 2, end: 3, feature: CharacterTypeNgram(&ct2u8![Roman]) }, - FeatureSpan { start: 3, end: 4, feature: CharacterTypeNgram(&ct2u8![Roman]) }, - FeatureSpan { start: 4, end: 5, feature: 
CharacterTypeNgram(&ct2u8![Hiragana]) }, - FeatureSpan { start: 5, end: 6, feature: CharacterTypeNgram(&ct2u8![Kanji]) }, - FeatureSpan { start: 6, end: 7, feature: CharacterTypeNgram(&ct2u8![Kanji]) }, - FeatureSpan { start: 7, end: 8, feature: CharacterTypeNgram(&ct2u8![Kanji]) }, - FeatureSpan { start: 8, end: 9, feature: CharacterTypeNgram(&ct2u8![Hiragana]) }, - FeatureSpan { start: 0, end: 2, feature: CharacterTypeNgram(&ct2u8![Roman, Roman]) }, - FeatureSpan { start: 1, end: 3, feature: CharacterTypeNgram(&ct2u8![Roman, Roman]) }, - FeatureSpan { start: 2, end: 4, feature: CharacterTypeNgram(&ct2u8![Roman, Roman]) }, - FeatureSpan { start: 3, end: 5, feature: CharacterTypeNgram(&ct2u8![Roman, Hiragana]) }, - FeatureSpan { start: 4, end: 6, feature: CharacterTypeNgram(&ct2u8![Hiragana, Kanji]) }, - FeatureSpan { start: 5, end: 7, feature: CharacterTypeNgram(&ct2u8![Kanji, Kanji]) }, - FeatureSpan { start: 6, end: 8, feature: CharacterTypeNgram(&ct2u8![Kanji, Kanji]) }, - FeatureSpan { start: 7, end: 9, feature: CharacterTypeNgram(&ct2u8![Kanji, Hiragana]) }, - FeatureSpan { start: 5, end: 8, feature: DictionaryWord(2) }, - FeatureSpan { start: 7, end: 8, feature: DictionaryWord(1) }, - ]; - assert_eq!(expected, feature_spans); - } - - #[test] - fn test_example_generator_new() { - let gen = ExampleGenerator::new(3, 2); - - assert_eq!(3, gen.char_window_size); - assert_eq!(2, gen.type_window_size); - } #[test] fn test_example_generator_generate_one() { - let dict = ["東京特許許可局", "火星猫", "猫"]; - let fe = FeatureExtractor::new(3, 2, dict, 2).unwrap(); - let gen = ExampleGenerator::new(3, 2); + let dict = Some(["東京特許許可局", "火星猫", "猫"]); + let gen = BoundaryExampleGenerator::new(3, 2, 3, 2, dict, 2).unwrap(); let s = Sentence::from_raw("猫").unwrap(); - let feature_spans = fe.extract(&s); - let examples = gen.generate(&s, feature_spans, true); + let examples = gen.generate(&s); - assert_eq!(Vec::::new(), examples); + assert!(examples.is_empty()); } #[test] fn test_example_generator_generate_all() { - let dict = ["東京特許許可局", "火星猫", "猫"]; - let fe = FeatureExtractor::new(3, 2, dict, 2).unwrap(); - let gen = ExampleGenerator::new(3, 2); + let dict = Some(["東京特許許可局", "火星猫", "猫"]); + let gen = BoundaryExampleGenerator::new(3, 2, 3, 2, dict, 2).unwrap(); let s = Sentence::from_partial_annotation("A-r-i-a|は|火-星 猫|だ").unwrap(); - let feature_spans = fe.extract(&s); - let examples = gen.generate(&s, feature_spans, true); + let examples = gen.generate(&s); - assert_eq!(8, examples.len()); + assert_eq!(7, examples.len()); // pos 3 "A-r" #[rustfmt::skip] - let expected = Example { + let expected = BoundaryExample { features: vec![ - Feature { rel_position: 3, feature: CharacterNgram("A") }, - Feature { rel_position: 2, feature: CharacterNgram("r") }, - Feature { rel_position: 1, feature: CharacterNgram("i") }, - Feature { rel_position: 0, feature: CharacterNgram("a") }, - Feature { rel_position: 2, feature: CharacterNgram("Ar") }, - Feature { rel_position: 1, feature: CharacterNgram("ri") }, - Feature { rel_position: 0, feature: CharacterNgram("ia") }, - Feature { rel_position: 1, feature: CharacterNgram("Ari") }, - Feature { rel_position: 0, feature: CharacterNgram("ria") }, - Feature { rel_position: 2, feature: CharacterTypeNgram(&ct2u8![Roman]) }, - Feature { rel_position: 1, feature: CharacterTypeNgram(&ct2u8![Roman]) }, - Feature { rel_position: 0, feature: CharacterTypeNgram(&ct2u8![Roman]) }, - Feature { rel_position: 1, feature: CharacterTypeNgram(&ct2u8![Roman, Roman]) }, - Feature { rel_position: 
0, feature: CharacterTypeNgram(&ct2u8![Roman, Roman]) }, + CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "A" }), + CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "r" }), + CharacterNgram(StringNgramFeature { rel_position: 1, ngram: "i" }), + CharacterNgram(StringNgramFeature { rel_position: 2, ngram: "a" }), + CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "Ar" }), + CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "ri" }), + CharacterNgram(StringNgramFeature { rel_position: 1, ngram: "ia" }), + CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "Ari" }), + CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "ria" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: b"R" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: b"R" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: 1, ngram: b"R" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: b"RR" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: b"RR" }), ], label: NotWordBoundary, }; @@ -385,101 +383,378 @@ mod tests { // pos 3 "a|は" #[rustfmt::skip] - let expected = Example { + let expected = BoundaryExample { features: vec![ - Feature { rel_position: 5, feature: CharacterNgram("r") }, - Feature { rel_position: 4, feature: CharacterNgram("i") }, - Feature { rel_position: 3, feature: CharacterNgram("a") }, - Feature { rel_position: 2, feature: CharacterNgram("は") }, - Feature { rel_position: 1, feature: CharacterNgram("火") }, - Feature { rel_position: 0, feature: CharacterNgram("星") }, - Feature { rel_position: 4, feature: CharacterNgram("ri") }, - Feature { rel_position: 3, feature: CharacterNgram("ia") }, - Feature { rel_position: 2, feature: CharacterNgram("aは") }, - Feature { rel_position: 1, feature: CharacterNgram("は火") }, - Feature { rel_position: 0, feature: CharacterNgram("火星") }, - Feature { rel_position: 3, feature: CharacterNgram("ria") }, - Feature { rel_position: 2, feature: CharacterNgram("iaは") }, - Feature { rel_position: 1, feature: CharacterNgram("aは火") }, - Feature { rel_position: 0, feature: CharacterNgram("は火星") }, - Feature { rel_position: 3, feature: CharacterTypeNgram(&ct2u8![Roman]) }, - Feature { rel_position: 2, feature: CharacterTypeNgram(&ct2u8![Roman]) }, - Feature { rel_position: 1, feature: CharacterTypeNgram(&ct2u8![Hiragana]) }, - Feature { rel_position: 0, feature: CharacterTypeNgram(&ct2u8![Kanji]) }, - Feature { rel_position: 2, feature: CharacterTypeNgram(&ct2u8![Roman, Roman]) }, - Feature { rel_position: 1, feature: CharacterTypeNgram(&ct2u8![Roman, Hiragana]) }, - Feature { rel_position: 0, feature: CharacterTypeNgram(&ct2u8![Hiragana, Kanji]) }, + CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "r" }), + CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "i" }), + CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "a" }), + CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "は" }), + CharacterNgram(StringNgramFeature { rel_position: 1, ngram: "火" }), + CharacterNgram(StringNgramFeature { rel_position: 2, ngram: "星" }), + CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "ri" }), + CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "ia" }), + CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "aは" }), + CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "は火" }), + CharacterNgram(StringNgramFeature { rel_position: 1, ngram: "火星" }), + CharacterNgram(StringNgramFeature 
{ rel_position: -3, ngram: "ria" }), + CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "iaは" }), + CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "aは火" }), + CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "は火星" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -2, ngram: b"R" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: b"R" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: b"H" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: 1, ngram: b"K" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -2, ngram: b"RR" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: b"RH" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: b"HK" }), ], label: WordBoundary, }; assert_eq!(expected, examples[3]); - // pos 6 "星 猫" - #[rustfmt::skip] - let expected = Example { - features: vec![ - Feature { rel_position: 5, feature: CharacterNgram("は") }, - Feature { rel_position: 4, feature: CharacterNgram("火") }, - Feature { rel_position: 3, feature: CharacterNgram("星") }, - Feature { rel_position: 2, feature: CharacterNgram("猫") }, - Feature { rel_position: 1, feature: CharacterNgram("だ") }, - Feature { rel_position: 4, feature: CharacterNgram("は火") }, - Feature { rel_position: 3, feature: CharacterNgram("火星") }, - Feature { rel_position: 2, feature: CharacterNgram("星猫") }, - Feature { rel_position: 1, feature: CharacterNgram("猫だ") }, - Feature { rel_position: 3, feature: CharacterNgram("は火星") }, - Feature { rel_position: 2, feature: CharacterNgram("火星猫") }, - Feature { rel_position: 1, feature: CharacterNgram("星猫だ") }, - Feature { rel_position: 3, feature: CharacterTypeNgram(&ct2u8![Kanji]) }, - Feature { rel_position: 2, feature: CharacterTypeNgram(&ct2u8![Kanji]) }, - Feature { rel_position: 1, feature: CharacterTypeNgram(&ct2u8![Kanji]) }, - Feature { rel_position: 0, feature: CharacterTypeNgram(&ct2u8![Hiragana]) }, - Feature { rel_position: 2, feature: CharacterTypeNgram(&ct2u8![Kanji, Kanji]) }, - Feature { rel_position: 1, feature: CharacterTypeNgram(&ct2u8![Kanji, Kanji]) }, - Feature { rel_position: 0, feature: CharacterTypeNgram(&ct2u8![Kanji, Hiragana]) }, - Feature { rel_position: 1, feature: DictionaryWord(2) }, - Feature { rel_position: 0, feature: DictionaryWord(1) }, - ], - label: Unknown, - }; - assert_eq!(expected, examples[6]); + // pos 6 "星 猫" (skipped) // pos 7 "猫|だ" #[rustfmt::skip] - let expected = Example { + let expected = BoundaryExample { features: vec![ - Feature { rel_position: 5, feature: CharacterNgram("火") }, - Feature { rel_position: 4, feature: CharacterNgram("星") }, - Feature { rel_position: 3, feature: CharacterNgram("猫") }, - Feature { rel_position: 2, feature: CharacterNgram("だ") }, - Feature { rel_position: 4, feature: CharacterNgram("火星") }, - Feature { rel_position: 3, feature: CharacterNgram("星猫") }, - Feature { rel_position: 2, feature: CharacterNgram("猫だ") }, - Feature { rel_position: 3, feature: CharacterNgram("火星猫") }, - Feature { rel_position: 2, feature: CharacterNgram("星猫だ") }, - Feature { rel_position: 3, feature: CharacterTypeNgram(&ct2u8![Kanji]) }, - Feature { rel_position: 2, feature: CharacterTypeNgram(&ct2u8![Kanji]) }, - Feature { rel_position: 1, feature: CharacterTypeNgram(&ct2u8![Hiragana]) }, - Feature { rel_position: 2, feature: CharacterTypeNgram(&ct2u8![Kanji, Kanji]) }, - Feature { rel_position: 1, feature: CharacterTypeNgram(&ct2u8![Kanji, Hiragana]) }, - Feature { rel_position: 2, feature: 
DictionaryWord(2) }, - Feature { rel_position: 2, feature: DictionaryWord(1) }, + CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "火" }), + CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "星" }), + CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "猫" }), + CharacterNgram(StringNgramFeature { rel_position: 0, ngram: "だ" }), + CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "火星" }), + CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "星猫" }), + CharacterNgram(StringNgramFeature { rel_position: -1, ngram: "猫だ" }), + CharacterNgram(StringNgramFeature { rel_position: -3, ngram: "火星猫" }), + CharacterNgram(StringNgramFeature { rel_position: -2, ngram: "星猫だ" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -2, ngram: b"K" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: b"K" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: 0, ngram: b"H" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -2, ngram: b"KK" }), + CharacterTypeNgram(BytesNgramFeature { rel_position: -1, ngram: b"KH" }), + DictionaryWord(DictionaryWordFeature { position: DictionaryWordPosition::Left, length: 2 }), + DictionaryWord(DictionaryWordFeature { position: DictionaryWordPosition::Left, length: 1 }), ], label: WordBoundary, }; - assert_eq!(expected, examples[7]); + assert_eq!(expected, examples[6]); } #[test] fn test_example_generator_generate_without_unknown() { - let dict = ["東京特許許可局", "火星猫", "猫"]; - let fe = FeatureExtractor::new(3, 2, dict, 2).unwrap(); - let gen = ExampleGenerator::new(3, 2); + let dict = Some(["東京特許許可局", "火星猫", "猫"]); + let gen = BoundaryExampleGenerator::new(3, 2, 3, 2, dict, 2).unwrap(); let s = Sentence::from_partial_annotation("A-r-i-a|は|火-星 猫|だ").unwrap(); - let feature_spans = fe.extract(&s); - let examples = gen.generate(&s, feature_spans, false); + let examples = gen.generate(&s); assert_eq!(7, examples.len()); } + + #[test] + fn test_tag_example_generate_33() { + let gen = TagExampleGenerator::new(3, 3); + + let s = + Sentence::from_partial_annotation("A-r-i-a/名詞|は/助詞|火-星 猫|だ/助動詞").unwrap(); + let mut examples = gen.generate(&s).unwrap(); + + // The order of examples is unimportant. 
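// Both `examples` and `expected` are normalized the same way below: the
// features inside each example are sorted, then the examples themselves are
// sorted, so the final `assert_eq!` is insensitive to generation order.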
+ examples + .iter_mut() + .for_each(|example| example.features.sort_unstable()); + examples.sort_unstable(); + + let mut expected = vec![ + TagExample { + features: vec![ + TagFeature::right_char_ngram(1, "iaは"), + TagFeature::right_char_ngram(1, "aは"), + TagFeature::right_char_ngram(1, "は"), + TagFeature::right_char_ngram(2, "aは火"), + TagFeature::right_char_ngram(2, "は火"), + TagFeature::right_char_ngram(2, "火"), + TagFeature::right_char_ngram(3, "は火星"), + TagFeature::right_char_ngram(3, "火星"), + TagFeature::right_char_ngram(3, "星"), + TagFeature::left_char_ngram_bos(-1, ""), + TagFeature::left_char_ngram_bos(-1, "A"), + TagFeature::left_char_ngram_bos(-1, "Ar"), + TagFeature::chars("Aria"), + ], + tag: Arc::new("名詞".to_string()), + }, + TagExample { + features: vec![ + TagFeature::right_char_ngram(1, "aは火"), + TagFeature::right_char_ngram(1, "は火"), + TagFeature::right_char_ngram(1, "火"), + TagFeature::right_char_ngram(2, "は火星"), + TagFeature::right_char_ngram(2, "火星"), + TagFeature::right_char_ngram(2, "星"), + TagFeature::right_char_ngram(3, "火星猫"), + TagFeature::right_char_ngram(3, "星猫"), + TagFeature::right_char_ngram(3, "猫"), + TagFeature::left_char_ngram(-3, "r"), + TagFeature::left_char_ngram(-3, "ri"), + TagFeature::left_char_ngram(-3, "ria"), + TagFeature::left_char_ngram(-2, "i"), + TagFeature::left_char_ngram(-2, "ia"), + TagFeature::left_char_ngram(-2, "iaは"), + TagFeature::left_char_ngram(-1, "a"), + TagFeature::left_char_ngram(-1, "aは"), + TagFeature::left_char_ngram(-1, "aは火"), + TagFeature::chars("は"), + ], + tag: Arc::new("助詞".to_string()), + }, + TagExample { + features: vec![ + TagFeature::right_char_ngram_eos(1, "猫だ"), + TagFeature::right_char_ngram_eos(1, "だ"), + TagFeature::right_char_ngram_eos(1, ""), + TagFeature::left_char_ngram(-3, "火"), + TagFeature::left_char_ngram(-3, "火星"), + TagFeature::left_char_ngram(-3, "火星猫"), + TagFeature::left_char_ngram(-2, "星"), + TagFeature::left_char_ngram(-2, "星猫"), + TagFeature::left_char_ngram(-2, "星猫だ"), + TagFeature::left_char_ngram(-1, "猫"), + TagFeature::left_char_ngram(-1, "猫だ"), + TagFeature::chars("だ"), + ], + tag: Arc::new("助動詞".to_string()), + }, + ]; + + expected + .iter_mut() + .for_each(|example| example.features.sort_unstable()); + expected.sort_unstable(); + + assert_eq!(expected, examples); + } + + #[test] + fn test_tag_example_generate_32() { + let gen = TagExampleGenerator::new(3, 2); + + let s = + Sentence::from_partial_annotation("A-r-i-a/名詞|は/助詞|火-星 猫|だ/助動詞").unwrap(); + let mut examples = gen.generate(&s).unwrap(); + + // The order of examples is unimportant. 
+ examples + .iter_mut() + .for_each(|example| example.features.sort_unstable()); + examples.sort_unstable(); + + let mut expected = vec![ + TagExample { + features: vec![ + TagFeature::right_char_ngram(1, "iaは"), + TagFeature::right_char_ngram(1, "aは"), + TagFeature::right_char_ngram(1, "は"), + TagFeature::right_char_ngram(2, "aは火"), + TagFeature::right_char_ngram(2, "は火"), + TagFeature::right_char_ngram(2, "火"), + TagFeature::left_char_ngram_bos(-1, ""), + TagFeature::left_char_ngram_bos(-1, "A"), + TagFeature::left_char_ngram_bos(-1, "Ar"), + TagFeature::chars("Aria"), + ], + tag: Arc::new("名詞".to_string()), + }, + TagExample { + features: vec![ + TagFeature::right_char_ngram(1, "aは火"), + TagFeature::right_char_ngram(1, "は火"), + TagFeature::right_char_ngram(1, "火"), + TagFeature::right_char_ngram(2, "は火星"), + TagFeature::right_char_ngram(2, "火星"), + TagFeature::right_char_ngram(2, "星"), + TagFeature::left_char_ngram(-2, "i"), + TagFeature::left_char_ngram(-2, "ia"), + TagFeature::left_char_ngram(-2, "iaは"), + TagFeature::left_char_ngram(-1, "a"), + TagFeature::left_char_ngram(-1, "aは"), + TagFeature::left_char_ngram(-1, "aは火"), + TagFeature::chars("は"), + ], + tag: Arc::new("助詞".to_string()), + }, + TagExample { + features: vec![ + TagFeature::right_char_ngram_eos(1, "猫だ"), + TagFeature::right_char_ngram_eos(1, "だ"), + TagFeature::right_char_ngram_eos(1, ""), + TagFeature::left_char_ngram(-2, "星"), + TagFeature::left_char_ngram(-2, "星猫"), + TagFeature::left_char_ngram(-2, "星猫だ"), + TagFeature::left_char_ngram(-1, "猫"), + TagFeature::left_char_ngram(-1, "猫だ"), + TagFeature::chars("だ"), + ], + tag: Arc::new("助動詞".to_string()), + }, + ]; + + expected + .iter_mut() + .for_each(|example| example.features.sort_unstable()); + expected.sort_unstable(); + + assert_eq!(expected, examples); + } + + #[test] + fn test_tag_example_generate_23() { + let gen = TagExampleGenerator::new(2, 3); + + let s = + Sentence::from_partial_annotation("A-r-i-a/名詞|は/助詞|火-星 猫|だ/助動詞").unwrap(); + let mut examples = gen.generate(&s).unwrap(); + + // The order of examples is unimportant. 
+ examples + .iter_mut() + .for_each(|example| example.features.sort_unstable()); + examples.sort_unstable(); + + let mut expected = vec![ + TagExample { + features: vec![ + TagFeature::right_char_ngram(1, "aは"), + TagFeature::right_char_ngram(1, "は"), + TagFeature::right_char_ngram(2, "は火"), + TagFeature::right_char_ngram(2, "火"), + TagFeature::right_char_ngram(3, "火星"), + TagFeature::right_char_ngram(3, "星"), + TagFeature::left_char_ngram_bos(-1, ""), + TagFeature::left_char_ngram_bos(-1, "A"), + TagFeature::chars("Aria"), + ], + tag: Arc::new("名詞".to_string()), + }, + TagExample { + features: vec![ + TagFeature::right_char_ngram(1, "は火"), + TagFeature::right_char_ngram(1, "火"), + TagFeature::right_char_ngram(2, "火星"), + TagFeature::right_char_ngram(2, "星"), + TagFeature::right_char_ngram(3, "星猫"), + TagFeature::right_char_ngram(3, "猫"), + TagFeature::left_char_ngram(-3, "r"), + TagFeature::left_char_ngram(-3, "ri"), + TagFeature::left_char_ngram(-2, "i"), + TagFeature::left_char_ngram(-2, "ia"), + TagFeature::left_char_ngram(-1, "a"), + TagFeature::left_char_ngram(-1, "aは"), + TagFeature::chars("は"), + ], + tag: Arc::new("助詞".to_string()), + }, + TagExample { + features: vec![ + TagFeature::right_char_ngram_eos(1, "だ"), + TagFeature::right_char_ngram_eos(1, ""), + TagFeature::left_char_ngram(-3, "火"), + TagFeature::left_char_ngram(-3, "火星"), + TagFeature::left_char_ngram(-2, "星"), + TagFeature::left_char_ngram(-2, "星猫"), + TagFeature::left_char_ngram(-1, "猫"), + TagFeature::left_char_ngram(-1, "猫だ"), + TagFeature::chars("だ"), + ], + tag: Arc::new("助動詞".to_string()), + }, + ]; + + expected + .iter_mut() + .for_each(|example| example.features.sort_unstable()); + expected.sort_unstable(); + + assert_eq!(expected, examples); + } + + #[test] + fn test_tag_example_generate_check_sentence_boundary() { + let gen = TagExampleGenerator::new(3, 3); + + let s = Sentence::from_tokenized("僕/代名詞 は/助詞 人間/名詞").unwrap(); + let mut examples = gen.generate(&s).unwrap(); + + // The order of examples is unimportant. 
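// This input is fully tokenized and only four characters long, so with a
// window size of 3 each expected example also contains the sentence-boundary
// feature variants (`left_char_ngram_bos` / `right_char_ngram_eos`) alongside
// the plain character n-gram features.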
+ examples + .iter_mut() + .for_each(|example| example.features.sort_unstable()); + examples.sort_unstable(); + + let mut expected = vec![ + TagExample { + features: vec![ + TagFeature::right_char_ngram(1, "僕は"), + TagFeature::right_char_ngram(1, "は"), + TagFeature::right_char_ngram(2, "僕は人"), + TagFeature::right_char_ngram(2, "は人"), + TagFeature::right_char_ngram(2, "人"), + TagFeature::right_char_ngram(3, "は人間"), + TagFeature::right_char_ngram(3, "人間"), + TagFeature::right_char_ngram(3, "間"), + TagFeature::left_char_ngram_bos(-1, ""), + TagFeature::left_char_ngram_bos(-1, "僕"), + TagFeature::left_char_ngram_bos(-1, "僕は"), + TagFeature::chars("僕"), + ], + tag: Arc::new("代名詞".to_string()), + }, + TagExample { + features: vec![ + TagFeature::right_char_ngram(1, "僕は人"), + TagFeature::right_char_ngram(1, "は人"), + TagFeature::right_char_ngram(1, "人"), + TagFeature::right_char_ngram(2, "は人間"), + TagFeature::right_char_ngram(2, "人間"), + TagFeature::right_char_ngram(2, "間"), + TagFeature::right_char_ngram_eos(3, "人間"), + TagFeature::right_char_ngram_eos(3, "間"), + TagFeature::right_char_ngram_eos(3, ""), + TagFeature::left_char_ngram_bos(-2, "僕は"), + TagFeature::left_char_ngram_bos(-2, "僕"), + TagFeature::left_char_ngram_bos(-2, ""), + TagFeature::left_char_ngram(-1, "僕は人"), + TagFeature::left_char_ngram(-1, "僕は"), + TagFeature::left_char_ngram(-1, "僕"), + TagFeature::chars("は"), + ], + tag: Arc::new("助詞".to_string()), + }, + TagExample { + features: vec![ + TagFeature::right_char_ngram_eos(1, "人間"), + TagFeature::right_char_ngram_eos(1, "間"), + TagFeature::right_char_ngram_eos(1, ""), + TagFeature::left_char_ngram_bos(-3, "僕は"), + TagFeature::left_char_ngram_bos(-3, "僕"), + TagFeature::left_char_ngram_bos(-3, ""), + TagFeature::left_char_ngram(-2, "僕は人"), + TagFeature::left_char_ngram(-2, "僕は"), + TagFeature::left_char_ngram(-2, "僕"), + TagFeature::left_char_ngram(-1, "は人間"), + TagFeature::left_char_ngram(-1, "は人"), + TagFeature::left_char_ngram(-1, "は"), + TagFeature::chars("人間"), + ], + tag: Arc::new("名詞".to_string()), + }, + ]; + + expected + .iter_mut() + .for_each(|example| example.features.sort_unstable()); + expected.sort_unstable(); + + assert_eq!(expected, examples); + } } diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index 78b407c9..67ee42d7 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -1,10 +1,12 @@ use std::convert::TryFrom; use std::io::BufRead; -use anyhow::{anyhow, Result}; -use byteorder::{LittleEndian, ReadBytesExt}; - -use crate::model::{DictWeight, Model}; +use crate::dict_model::{DictModel, DictWeight, WordWeightRecord}; +use crate::errors::{Result, VaporettoError}; +use crate::model::Model; +use crate::ngram_model::{NgramData, NgramModel}; +use crate::tag_model::TagModel; +use crate::utils; struct KyteaConfig { _model_tag: String, @@ -23,20 +25,23 @@ struct KyteaConfig { } impl KyteaConfig { - fn read(rdr: &mut R) -> Result { + fn read(mut rdr: R) -> Result + where + R: BufRead, + { let mut model_tag = String::new(); rdr.read_line(&mut model_tag)?; - let do_ws = rdr.read_u8()? != 0; - let do_tags = rdr.read_u8()? != 0; - let n_tags = rdr.read_u32::()?; - let char_w = rdr.read_u8()?; - let char_n = rdr.read_u8()?; - let type_w = rdr.read_u8()?; - let type_n = rdr.read_u8()?; - let dict_n = rdr.read_u8()?; - let bias = rdr.read_u8()? != 0; - let epsilon = rdr.read_f64::()?; - let solver_type = rdr.read_u8()?; + let do_ws = utils::read_u8(&mut rdr)? != 0; + let do_tags = utils::read_u8(&mut rdr)? 
!= 0; + let n_tags = utils::read_u32(&mut rdr)?; + let char_w = utils::read_u8(&mut rdr)?; + let char_n = utils::read_u8(&mut rdr)?; + let type_w = utils::read_u8(&mut rdr)?; + let type_n = utils::read_u8(&mut rdr)?; + let dict_n = utils::read_u8(&mut rdr)?; + let bias = utils::read_u8(&mut rdr)? != 0; + let epsilon = utils::read_f64(&mut rdr)?; + let solver_type = utils::read_u8(&mut rdr)?; let mut char_map = vec![]; rdr.read_until(0, &mut char_map)?; let char_map: Vec = String::from_utf8(char_map)?.chars().collect(); @@ -59,24 +64,35 @@ impl KyteaConfig { } trait Readable: Sized { - fn read(config: &KyteaConfig, rdr: &mut R) -> Result; + fn read(config: &KyteaConfig, rdr: R) -> Result + where + R: BufRead; } impl Readable for i16 { - fn read(_config: &KyteaConfig, rdr: &mut R) -> Result { - Ok(rdr.read_i16::()?) + fn read(_config: &KyteaConfig, mut rdr: R) -> Result + where + R: BufRead, + { + Ok(utils::read_i16(&mut rdr)?) } } impl Readable for f64 { - fn read(_config: &KyteaConfig, rdr: &mut R) -> Result { - Ok(rdr.read_f64::()?) + fn read(_config: &KyteaConfig, mut rdr: R) -> Result + where + R: BufRead, + { + Ok(utils::read_f64(&mut rdr)?) } } impl Readable for char { - fn read(config: &KyteaConfig, rdr: &mut R) -> Result { - let cidx = rdr.read_u16::()? as usize; + fn read(config: &KyteaConfig, mut rdr: R) -> Result + where + R: BufRead, + { + let cidx = utils::read_u16(&mut rdr)? as usize; Ok(config.char_map[cidx - 1]) } } @@ -85,22 +101,28 @@ impl Readable for Vec where T: Readable, { - fn read(config: &KyteaConfig, rdr: &mut R) -> Result { - let size = rdr.read_u32::()?; + fn read(config: &KyteaConfig, mut rdr: R) -> Result + where + R: BufRead, + { + let size = utils::read_u32(&mut rdr)?; let mut result = Self::with_capacity(size as usize); for _ in 0..size { - result.push(T::read(config, rdr)?); + result.push(T::read(config, &mut rdr)?); } Ok(result) } } impl Readable for String { - fn read(config: &KyteaConfig, rdr: &mut R) -> Result { - let size = rdr.read_u32::()?; + fn read(config: &KyteaConfig, mut rdr: R) -> Result + where + R: BufRead, + { + let size = utils::read_u32(&mut rdr)?; let mut result = Self::new(); for _ in 0..size { - let cidx = rdr.read_u16::()? as usize; + let cidx = utils::read_u16(&mut rdr)? as usize; result.push(config.char_map[cidx - 1]); } Ok(result) @@ -149,29 +171,32 @@ impl Dictionary where T: Readable, { - fn read(config: &KyteaConfig, rdr: &mut R) -> Result> { - let n_dicts = rdr.read_u8()?; - let n_states = rdr.read_u32::()? as usize; + fn read(config: &KyteaConfig, mut rdr: R) -> Result> + where + R: BufRead, + { + let n_dicts = utils::read_u8(&mut rdr)?; + let n_states = utils::read_u32(&mut rdr)? as usize; if n_states == 0 { return Ok(None); } let mut states = Vec::with_capacity(n_states); for _ in 0..n_states { - let failure = rdr.read_u32::()?; - let n_gotos = rdr.read_u32::()?; + let failure = utils::read_u32(&mut rdr)?; + let n_gotos = utils::read_u32(&mut rdr)?; let mut gotos = vec![]; for _ in 0..n_gotos { - let k = char::read(config, rdr)?; - let v = rdr.read_u32::()?; + let k = char::read(config, &mut rdr)?; + let v = utils::read_u32(&mut rdr)?; gotos.push((k, v)); } gotos.sort_unstable(); - let n_outputs = rdr.read_u32::()? as usize; + let n_outputs = utils::read_u32(&mut rdr)? as usize; let mut outputs = Vec::with_capacity(n_outputs); for _ in 0..n_outputs { - outputs.push(rdr.read_u32::()?); + outputs.push(utils::read_u32(&mut rdr)?); } - let is_branch = rdr.read_u8()? != 0; + let is_branch = utils::read_u8(&mut rdr)? 
!= 0; states.push(State { _failure: failure, gotos, @@ -179,10 +204,10 @@ where is_branch, }); } - let n_entries = rdr.read_u32::()? as usize; + let n_entries = utils::read_u32(&mut rdr)? as usize; let mut entries = Vec::with_capacity(n_entries); for _ in 0..n_entries { - entries.push(T::read(config, rdr)?); + entries.push(T::read(config, &mut rdr)?); } Ok(Some(Self { n_dicts, @@ -209,18 +234,21 @@ impl FeatureLookup where T: Readable, { - fn read(config: &KyteaConfig, rdr: &mut R) -> Result> { - let active = rdr.read_u8()?; + fn read(config: &KyteaConfig, mut rdr: R) -> Result> + where + R: BufRead, + { + let active = utils::read_u8(&mut rdr)?; if active == 0 { return Ok(None); } - let char_dict = Dictionary::read(config, rdr)?; - let type_dict = Dictionary::read(config, rdr)?; - let self_dict = Dictionary::read(config, rdr)?; - let dict_vec = Vec::::read(config, rdr)?; - let biases = Vec::::read(config, rdr)?; - let tag_dict_vec = Vec::::read(config, rdr)?; - let tag_unk_vec = Vec::::read(config, rdr)?; + let char_dict = Dictionary::read(config, &mut rdr)?; + let type_dict = Dictionary::read(config, &mut rdr)?; + let self_dict = Dictionary::read(config, &mut rdr)?; + let dict_vec = Vec::::read(config, &mut rdr)?; + let biases = Vec::::read(config, &mut rdr)?; + let tag_dict_vec = Vec::::read(config, &mut rdr)?; + let tag_unk_vec = Vec::::read(config, &mut rdr)?; Ok(Some(Self { char_dict, type_dict, @@ -238,31 +266,34 @@ struct LinearModel { _solver_type: u8, _labels: Vec, _bias: bool, - multiplier: f64, + _multiplier: f64, feature_lookup: Option>, } impl Readable for Option { - fn read(config: &KyteaConfig, rdr: &mut R) -> Result { - let n_classes = rdr.read_u32::()?; + fn read(config: &KyteaConfig, mut rdr: R) -> Result + where + R: BufRead, + { + let n_classes = utils::read_u32(&mut rdr)?; if n_classes == 0 { return Ok(None); } let add_features = false; - let solver_type = rdr.read_u8()?; + let solver_type = utils::read_u8(&mut rdr)?; let mut labels = vec![]; for _ in 0..n_classes { - labels.push(rdr.read_i32::()?); + labels.push(utils::read_i32(&mut rdr)?); } - let bias = rdr.read_u8()? != 0; - let multiplier = rdr.read_f64::()?; - let feature_lookup = FeatureLookup::read(config, rdr)?; + let bias = utils::read_u8(&mut rdr)? != 0; + let multiplier = utils::read_f64(&mut rdr)?; + let feature_lookup = FeatureLookup::read(config, &mut rdr)?; Ok(Some(LinearModel { _add_features: add_features, _solver_type: solver_type, _labels: labels, _bias: bias, - multiplier, + _multiplier: multiplier, feature_lookup, })) } @@ -277,25 +308,28 @@ struct ModelTagEntry { } impl Readable for ModelTagEntry { - fn read(config: &KyteaConfig, rdr: &mut R) -> Result { - let word = String::read(config, rdr)?; + fn read(config: &KyteaConfig, mut rdr: R) -> Result + where + R: BufRead, + { + let word = String::read(config, &mut rdr)?; let mut tags = Vec::with_capacity(config.n_tags as usize); let mut tags_in_dicts = Vec::with_capacity(config.n_tags as usize); for _ in 0..config.n_tags { - let size = rdr.read_u32::()? as usize; + let size = utils::read_u32(&mut rdr)? 
as usize; let mut t = Vec::with_capacity(size); let mut td = Vec::with_capacity(size); for _ in 0..size { - t.push(String::read(config, rdr)?); - td.push(rdr.read_u8()?); + t.push(String::read(config, &mut rdr)?); + td.push(utils::read_u8(&mut rdr)?); } tags.push(t); tags_in_dicts.push(td); } - let in_dict = rdr.read_u8()?; + let in_dict = utils::read_u8(&mut rdr)?; let mut tag_models = Vec::with_capacity(config.n_tags as usize); for _ in 0..config.n_tags { - tag_models.push(Option::::read(config, rdr)?); + tag_models.push(Option::::read(config, &mut rdr)?); } Ok(Self { _word: word, @@ -314,17 +348,20 @@ struct ProbTagEntry { } impl Readable for ProbTagEntry { - fn read(config: &KyteaConfig, rdr: &mut R) -> Result { - let word = String::read(config, rdr)?; + fn read(config: &KyteaConfig, mut rdr: R) -> Result + where + R: BufRead, + { + let word = String::read(config, &mut rdr)?; let mut tags = Vec::with_capacity(config.n_tags as usize); let mut probs = Vec::with_capacity(config.n_tags as usize); for _ in 0..config.n_tags { - let size = rdr.read_u32::()? as usize; + let size = utils::read_u32(&mut rdr)? as usize; let mut t = Vec::with_capacity(size); let mut p = Vec::with_capacity(size); for _ in 0..size { - t.push(String::read(config, rdr)?); - p.push(rdr.read_f64::()?); + t.push(String::read(config, &mut rdr)?); + p.push(utils::read_f64(&mut rdr)?); } tags.push(t); probs.push(p); @@ -362,21 +399,24 @@ impl KyteaModel { /// # Errors /// /// When `rdr` generates an error, it will be returned as is. - pub fn read(rdr: &mut R) -> Result { - let config = KyteaConfig::read(rdr)?; + pub fn read(mut rdr: R) -> Result + where + R: BufRead, + { + let config = KyteaConfig::read(&mut rdr)?; - let wordseg_model = Option::::read(&config, rdr)?; + let wordseg_model = Option::::read(&config, &mut rdr)?; let mut global_tags = Vec::with_capacity(config.n_tags as usize); let mut global_models = Vec::with_capacity(config.n_tags as usize); for _ in 0..config.n_tags { - global_tags.push(Vec::::read(&config, rdr)?); - global_models.push(Option::::read(&config, rdr)?); + global_tags.push(Vec::::read(&config, &mut rdr)?); + global_models.push(Option::::read(&config, &mut rdr)?); } - let dict = Dictionary::::read(&config, rdr)?; - let subword_dict = Dictionary::::read(&config, rdr)?; + let dict = Dictionary::::read(&config, &mut rdr)?; + let subword_dict = Dictionary::::read(&config, &mut rdr)?; Ok(Self { config, @@ -390,43 +430,47 @@ impl KyteaModel { } impl TryFrom for Model { - type Error = anyhow::Error; + type Error = VaporettoError; fn try_from(model: KyteaModel) -> Result { let config = &model.config; let wordseg_model = model .wordseg_model - .ok_or_else(|| anyhow!("no word segmentation model."))?; - let quantize_multiplier = wordseg_model.multiplier; + .ok_or_else(|| VaporettoError::invalid_model("no word segmentation model."))?; let feature_lookup = wordseg_model .feature_lookup - .ok_or_else(|| anyhow!("no lookup data."))?; - let bias = feature_lookup.biases[0]; + .ok_or_else(|| VaporettoError::invalid_model("no lookup data."))?; + let bias = feature_lookup.biases[0] as i32; let char_dict = feature_lookup .char_dict - .ok_or_else(|| anyhow!("no character dictionary."))?; + .ok_or_else(|| VaporettoError::invalid_model("no character dictionary."))?; let type_dict = feature_lookup .type_dict - .ok_or_else(|| anyhow!("no type dictionary."))?; - - let mut words: Vec> = vec![]; - let mut word_weights = vec![]; - for (word, v) in char_dict.dump_items() { - let weight_size = config.char_w as usize * 2 - 
word.len() + 1; - words.push(word.into_iter().collect::().as_bytes().to_vec()); - word_weights.push(v[..weight_size].to_vec()); + .ok_or_else(|| VaporettoError::invalid_model("no type dictionary."))?; + + let mut char_ngrams = vec![]; + for (char_ngram, v) in char_dict.dump_items() { + let weight_size = config.char_w as usize * 2 - char_ngram.len() + 1; + char_ngrams.push(NgramData { + ngram: char_ngram.into_iter().collect(), + weights: v[..weight_size].iter().map(|&w| w as i32).collect(), + }); } - let mut types: Vec> = vec![]; - let mut type_weights = vec![]; - for (word, v) in type_dict.dump_items() { - let weight_size = config.type_w as usize * 2 - word.len() + 1; - types.push(word.into_iter().collect::().as_bytes().to_vec()); - type_weights.push(v[..weight_size].to_vec()); + let mut type_ngrams = vec![]; + for (type_ngram, v) in type_dict.dump_items() { + let weight_size = config.type_w as usize * 2 - type_ngram.len() + 1; + type_ngrams.push(NgramData { + ngram: type_ngram + .into_iter() + .collect::() + .as_bytes() + .to_vec(), + weights: v[..weight_size].iter().map(|&w| w as i32).collect(), + }); } - let mut dict: Vec> = vec![]; - let mut dict_weights = vec![]; + let mut dict = vec![]; if let Some(kytea_dict) = model.dict { for (w, data) in kytea_dict.dump_items() { let word_len = std::cmp::min(w.len(), config.dict_n as usize) - 1; @@ -435,28 +479,24 @@ impl TryFrom for Model { if data.in_dict >> j & 1 == 1 { let offset = 3 * config.dict_n as usize * j + 3 * word_len; weights.right += feature_lookup.dict_vec[offset] as i32; - weights.inner += feature_lookup.dict_vec[offset + 1] as i32; + weights.inside += feature_lookup.dict_vec[offset + 1] as i32; weights.left += feature_lookup.dict_vec[offset + 2] as i32; } } - dict_weights.push(weights); - dict.push(w.into_iter().collect::().as_bytes().to_vec()); + dict.push(WordWeightRecord { + word: w.into_iter().collect(), + weights, + comment: "".to_string(), + }); } } Ok(Self { - words, - types, - dict, - - #[cfg(feature = "model-quantize")] - quantize_multiplier, - - word_weights, - type_weights, - dict_weights, - dict_word_wise: true, + char_ngram_model: NgramModel::new(char_ngrams), + type_ngram_model: NgramModel::new(type_ngrams), + dict_model: DictModel::new(dict), bias, + tag_model: TagModel::default(), char_window_size: config.char_w as usize, type_window_size: config.type_w as usize, }) diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs index d7107e66..8985eff0 100644 --- a/vaporetto/src/lib.rs +++ b/vaporetto/src/lib.rs @@ -1,4 +1,5 @@ #![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(feature = "portable-simd", feature(portable_simd))] //! # Vaporetto //! @@ -14,43 +15,45 @@ //! //! let mut f = BufReader::new(File::open("model.bin").unwrap()); //! let model = Model::read(&mut f).unwrap(); -//! let predictor = Predictor::new(model); -//! -//! for line in stdin().lock().lines() { -//! let s = Sentence::from_raw(line.unwrap()).unwrap(); -//! let s = predictor.predict(s); -//! let toks = s.to_tokenized_string().unwrap(); -//! println!("{}", toks); -//! } +//! let predictor = Predictor::new(model, false).unwrap(); +//! +//! let s = Sentence::from_raw("火星猫の生態").unwrap(); +//! let s = predictor.predict(s); +//! +//! println!("{:?}", s.to_tokenized_vec().unwrap()); //! ``` //! //! Training requires **crate feature** `train`. For more details, see [`Trainer`]. 
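// Editor's note: a minimal sketch, not part of this patch, showing how the
// APIs introduced in this diff fit together when tag prediction is enabled.
// "model.bin" and the input string are placeholders; errors are unwrapped for
// brevity.
fn main() {
    use std::fs::File;
    use std::io::BufReader;
    use vaporetto::{Model, Predictor, Sentence};

    let mut f = BufReader::new(File::open("model.bin").unwrap());
    let model = Model::read(&mut f).unwrap();
    // Passing `true` makes the predictor build the tag scorers as well.
    let predictor = Predictor::new(model, true).unwrap();
    let s = Sentence::from_raw("火星猫の生態").unwrap();
    let s = predictor.predict(s);
    // Tags are filled in from the predicted boundaries, so any boundary
    // post-processing must run before this call.
    let s = predictor.fill_tags(s);
    println!("{:?}", s.to_tokenized_vec().unwrap());
}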
-#[macro_use] -mod utils; - +mod char_scorer; +mod dict_model; mod model; +mod ngram_model; mod predictor; mod sentence; +mod tag_model; mod type_scorer; +mod utils; + +pub mod errors; #[cfg(feature = "train")] mod feature; #[cfg(feature = "train")] +mod tag_trainer; +#[cfg(feature = "train")] mod trainer; #[cfg(feature = "kytea")] mod kytea_model; +pub use dict_model::WordWeightRecord; pub use model::Model; pub use predictor::Predictor; -pub use sentence::{BoundaryType, CharacterType, Sentence}; - -#[cfg(feature = "multithreading")] -pub use predictor::MultithreadPredictor; +pub use sentence::{BoundaryType, CharacterType, Sentence, Token}; #[cfg(feature = "train")] -pub use trainer::{Dataset, SolverType, Trainer}; +pub use trainer::{SolverType, Trainer}; #[cfg(feature = "kytea")] pub use kytea_model::KyteaModel; diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs index b2465346..f31505cf 100644 --- a/vaporetto/src/model.rs +++ b/vaporetto/src/model.rs @@ -1,54 +1,20 @@ use std::io::{Read, Write}; -use anyhow::Result; -use serde::{Deserialize, Serialize}; - -#[cfg(feature = "train")] -use crate::feature::FeatureContent; -#[cfg(feature = "train")] -use crate::sentence::BoundaryType; -#[cfg(feature = "train")] -use crate::utils::{FeatureIDManager, StringIdManager}; -#[cfg(feature = "train")] -use liblinear::LibLinearModel; -#[cfg(feature = "train")] -const EPSILON: f64 = 1e-6; - -#[cfg(not(feature = "model-quantize"))] -pub type WeightValue = f64; -#[cfg(feature = "model-quantize")] -pub type WeightValue = i16; -#[cfg(not(feature = "model-quantize"))] -pub type ScoreValue = f64; -#[cfg(feature = "model-quantize")] -pub type ScoreValue = i32; - -#[derive(Clone, Copy, Default, Serialize, Deserialize)] -pub struct DictWeight { - pub right: ScoreValue, - pub inner: ScoreValue, - pub left: ScoreValue, -} +use crate::dict_model::{DictModel, WordWeightRecord}; +use crate::errors::Result; +use crate::ngram_model::NgramModel; +use crate::tag_model::TagModel; +use crate::utils; /// Model data. -#[derive(Serialize, Deserialize)] pub struct Model { - pub(crate) words: Vec>, - pub(crate) types: Vec>, - pub(crate) dict: Vec>, - - pub(crate) word_weights: Vec>, - pub(crate) type_weights: Vec>, - pub(crate) dict_weights: Vec, - - #[cfg(feature = "model-quantize")] - pub(crate) quantize_multiplier: f64, - - pub(crate) dict_word_wise: bool, - - pub(crate) bias: WeightValue, + pub(crate) char_ngram_model: NgramModel, + pub(crate) type_ngram_model: NgramModel>, + pub(crate) dict_model: DictModel, + pub(crate) bias: i32, pub(crate) char_window_size: usize, pub(crate) type_window_size: usize, + pub(crate) tag_model: TagModel, } impl Model { @@ -61,11 +27,17 @@ impl Model { /// # Errors /// /// When `wtr` generates an error, it will be returned as is. - pub fn write(&self, wtr: &mut W) -> Result<()> + pub fn write(&self, mut wtr: W) -> Result<()> where W: Write, { - bincode::serialize_into(wtr, self)?; + self.char_ngram_model.serialize(&mut wtr)?; + self.type_ngram_model.serialize(&mut wtr)?; + self.dict_model.serialize(&mut wtr)?; + utils::write_i32(&mut wtr, self.bias)?; + utils::write_u32(&mut wtr, self.char_window_size.try_into().unwrap())?; + utils::write_u32(&mut wtr, self.type_window_size.try_into().unwrap())?; + self.tag_model.serialize(&mut wtr)?; Ok(()) } @@ -82,107 +54,26 @@ impl Model { /// # Errors /// /// When `rdr` generates an error, it will be returned as is. - pub fn read(rdr: &mut R) -> Result + pub fn read(mut rdr: R) -> Result where R: Read, { - Ok(bincode::deserialize_from(rdr)?) 
+ Ok(Self { + char_ngram_model: NgramModel::::deserialize(&mut rdr)?, + type_ngram_model: NgramModel::>::deserialize(&mut rdr)?, + dict_model: DictModel::deserialize(&mut rdr)?, + bias: utils::read_i32(&mut rdr)?, + char_window_size: utils::read_u32(&mut rdr)?.try_into().unwrap(), + type_window_size: utils::read_u32(&mut rdr)?.try_into().unwrap(), + tag_model: TagModel::deserialize(&mut rdr)?, + }) } - #[cfg(feature = "train")] - pub(crate) fn from_liblinear_model( - model: impl LibLinearModel, - fid_manager: FeatureIDManager, - dict: Vec>, - char_window_size: usize, - type_window_size: usize, - dict_word_max_size: usize, - ) -> Self { - let wb_idx = model - .labels() - .iter() - .position(|&cls| BoundaryType::WordBoundary as i32 == cls) - .unwrap() as i32; - - let bias = model.label_bias(wb_idx); - let mut words = vec![]; - let mut types = vec![]; - let mut word_weights = vec![]; - let mut type_weights = vec![]; - let mut dict_weights = vec![DictWeight::default(); dict_word_max_size]; - let mut word_ids = StringIdManager::new(); - let mut type_ids = StringIdManager::new(); - - #[cfg(feature = "model-quantize")] - let quantize_multiplier = { - let mut weight_max = bias.abs(); - for fid in 0..model.num_features() { - let weight = model.feature_coefficient(fid as i32, wb_idx).abs(); - if weight > weight_max { - weight_max = weight; - } - } - weight_max / 32767. - }; - - #[cfg(feature = "model-quantize")] - let bias = (bias / quantize_multiplier) as i16; - - for (feature, fid) in fid_manager.map { - let weight = model.feature_coefficient(fid as i32 + 1, wb_idx); - if weight > -EPSILON && weight < EPSILON { - continue; - } - - #[cfg(feature = "model-quantize")] - let weight = weight / quantize_multiplier; - - match feature.feature { - FeatureContent::CharacterNgram(word) => { - let id = word_ids.get_id(word.as_bytes()); - if id == word_weights.len() { - words.push(word.as_bytes().to_vec()); - word_weights.push(vec![ - WeightValue::default(); - char_window_size * 2 - word.chars().count() + 1 - ]); - } - word_weights[id][feature.rel_position] = weight as WeightValue; - } - FeatureContent::CharacterTypeNgram(word) => { - let id = type_ids.get_id(word) as usize; - if id == type_weights.len() { - types.push(word.to_vec()); - type_weights.push(vec![ - WeightValue::default(); - type_window_size * 2 - word.len() + 1 - ]); - } - type_weights[id][feature.rel_position] = weight as WeightValue; - } - FeatureContent::DictionaryWord(size) => match feature.rel_position { - 0 => dict_weights[size - 1].right = weight as ScoreValue, - 1 => dict_weights[size - 1].inner = weight as ScoreValue, - 2 => dict_weights[size - 1].left = weight as ScoreValue, - _ => panic!("Invalid rel_position"), - }, - }; - } - Self { - words, - types, - dict, - - #[cfg(feature = "model-quantize")] - quantize_multiplier, + pub fn dictionary(&self) -> &[WordWeightRecord] { + self.dict_model.dictionary() + } - word_weights, - type_weights, - dict_weights, - dict_word_wise: false, - bias, - char_window_size, - type_window_size, - } + pub fn replace_dictionary(&mut self, dict: Vec) { + self.dict_model = DictModel::new(dict); } } diff --git a/vaporetto/src/ngram_model.rs b/vaporetto/src/ngram_model.rs new file mode 100644 index 00000000..d79b287d --- /dev/null +++ b/vaporetto/src/ngram_model.rs @@ -0,0 +1,129 @@ +use std::io::{Read, Write}; +use std::mem; + +use crate::errors::Result; +use crate::utils; + +#[derive(Clone)] +pub struct NgramData +where + T: Clone, +{ + pub(crate) ngram: T, + pub(crate) weights: Vec, +} + +impl NgramData 
+where + T: AsRef<[u8]> + Clone, +{ + pub fn serialize(&self, mut wtr: W) -> Result + where + W: Write, + { + let ngram = self.ngram.as_ref(); + let ngram_size = ngram.len(); + let weights_size = self.weights.len(); + utils::write_u32(&mut wtr, ngram_size.try_into().unwrap())?; + utils::write_u32(&mut wtr, weights_size.try_into().unwrap())?; + wtr.write_all(ngram)?; + for &w in &self.weights { + utils::write_i32(&mut wtr, w)?; + } + Ok(mem::size_of::() * 2 + ngram_size + mem::size_of::() * weights_size) + } +} + +impl NgramData { + pub fn deserialize(mut rdr: R) -> Result + where + R: Read, + { + let ngram_size = utils::read_u32(&mut rdr)?; + let weights_size = utils::read_u32(&mut rdr)?; + let mut ngram_bytes = vec![0; ngram_size.try_into().unwrap()]; + rdr.read_exact(&mut ngram_bytes)?; + let ngram = String::from_utf8(ngram_bytes)?; + let mut weights = vec![]; + for _ in 0..weights_size { + weights.push(utils::read_i32(&mut rdr)?); + } + Ok(Self { ngram, weights }) + } +} + +impl NgramData> { + pub fn deserialize(mut rdr: R) -> Result + where + R: Read, + { + let ngram_size = utils::read_u32(&mut rdr)?; + let weights_size = utils::read_u32(&mut rdr)?; + let mut ngram = vec![0; ngram_size.try_into().unwrap()]; + rdr.read_exact(&mut ngram)?; + let mut weights = Vec::with_capacity(weights_size.try_into().unwrap()); + for _ in 0..weights_size { + weights.push(utils::read_i32(&mut rdr)?); + } + Ok(Self { ngram, weights }) + } +} + +#[derive(Default)] +pub struct NgramModel +where + T: Clone, +{ + pub(crate) data: Vec>, +} + +impl NgramModel +where + T: AsRef<[u8]> + Clone, +{ + #[cfg(any(feature = "train", feature = "kytea", test))] + pub fn new(data: Vec>) -> Self { + Self { data } + } + + pub fn serialize(&self, mut wtr: W) -> Result + where + W: Write, + { + let data_size = self.data.len(); + utils::write_u32(&mut wtr, data_size.try_into().unwrap())?; + let mut total_size = mem::size_of::(); + for d in &self.data { + total_size += d.serialize(&mut wtr)?; + } + Ok(total_size + mem::size_of::()) + } +} + +impl NgramModel { + pub fn deserialize(mut rdr: R) -> Result + where + R: Read, + { + let data_size = utils::read_u32(&mut rdr)?; + let mut data = Vec::with_capacity(data_size.try_into().unwrap()); + for _ in 0..data_size { + data.push(NgramData::::deserialize(&mut rdr)?); + } + Ok(Self { data }) + } +} + +impl NgramModel> { + pub fn deserialize(mut rdr: R) -> Result + where + R: Read, + { + let data_size = utils::read_u32(&mut rdr)?; + let mut data = Vec::with_capacity(data_size.try_into().unwrap()); + for _ in 0..data_size { + data.push(NgramData::>::deserialize(&mut rdr)?); + } + Ok(Self { data }) + } +} diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index a2a6f245..2fde23c5 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -1,37 +1,31 @@ -use std::collections::HashMap; -use std::ops::Range; +use std::mem; -#[cfg(feature = "multithreading")] -use std::cell::RefCell; -#[cfg(feature = "multithreading")] +use std::cmp::Ordering; use std::sync::Arc; -#[cfg(feature = "multithreading")] -use std::thread; -#[cfg(feature = "multithreading")] -use crossbeam_channel::{Receiver, Sender}; - -use crate::model::{DictWeight, Model, ScoreValue}; +use crate::char_scorer::{self, CharScorer, CharScorerWithTags}; +use crate::errors::Result; +use crate::model::Model; use crate::sentence::{BoundaryType, Sentence}; use crate::type_scorer::TypeScorer; -use daachorse::DoubleArrayAhoCorasick; +enum CharScorerWrapper { + Boundary(CharScorer), + 
BoundaryAndTags(CharScorerWithTags), +} /// Predictor. pub struct Predictor { - word_pma: DoubleArrayAhoCorasick, - dict_pma: DoubleArrayAhoCorasick, - word_weights: Vec>, - dict_weights: Vec, - dict_word_wise: bool, - bias: ScoreValue, - char_window_size: usize, - dict_window_size: usize, + bias: i32, + char_scorer: CharScorerWrapper, type_scorer: TypeScorer, - #[cfg(feature = "model-quantize")] - quantize_multiplier: f64, + padding: usize, + + // for tag prediction + tag_names: Vec>, + tag_bias: Vec, } impl Predictor { @@ -40,291 +34,80 @@ impl Predictor { /// # Arguments /// /// * `model` - A model data. + /// * `predict_tags` - If you want to predict tags, set to true. /// /// # Returns /// /// A new predictor. - pub fn new(model: Model) -> Self { - let bias = model.bias; + pub fn new(model: Model, predict_tags: bool) -> Result { + let mut tag_names = vec![]; + let mut tag_bias = vec![]; - let words = model.words; - let dict = model.dict; - let dict_weights = model.dict_weights; - - let mut word_weights: Vec<_> = model - .word_weights - .into_iter() - .map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect()) - .collect(); - let type_weights: Vec<_> = model - .type_weights - .into_iter() - .map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect()) - .collect(); - - let (dict, dict_weights) = Self::merge_dict_weights( - dict, - dict_weights, - &words, - &mut word_weights, - model.char_window_size, - model.dict_word_wise, - ); - - let word_weights = Self::merge_weights(&words, &word_weights); - let type_weights = Self::merge_weights(&model.types, &type_weights); - - #[cfg(feature = "model-quantize")] - let bias = bias as i32; - - let word_pma = DoubleArrayAhoCorasick::new(words).unwrap(); - let type_pma = DoubleArrayAhoCorasick::new(model.types).unwrap(); - let dict_pma = DoubleArrayAhoCorasick::new(dict).unwrap(); - - let type_scorer = TypeScorer::new(type_pma, type_weights, model.type_window_size); + let char_scorer = if predict_tags { + for cls in model.tag_model.class_info { + tag_names.push(Arc::new(cls.name)); + tag_bias.push(cls.bias); + } + CharScorerWrapper::BoundaryAndTags(CharScorerWithTags::new( + model.char_ngram_model, + model.char_window_size, + model.dict_model, + tag_names.len(), + model.tag_model.left_char_model, + model.tag_model.right_char_model, + model.tag_model.self_char_model, + )?) + } else { + CharScorerWrapper::Boundary(CharScorer::new( + model.char_ngram_model, + model.char_window_size, + model.dict_model, + )?) 
+ }; + let type_scorer = TypeScorer::new(model.type_ngram_model, model.type_window_size)?; - Self { - word_pma, - dict_pma, - word_weights, - dict_weights, - dict_word_wise: model.dict_word_wise, - bias, - char_window_size: model.char_window_size, - dict_window_size: 1, + Ok(Self { + bias: model.bias, + char_scorer, type_scorer, - #[cfg(feature = "model-quantize")] - quantize_multiplier: model.quantize_multiplier, - } - } + padding: model.char_window_size.max(model.type_window_size), - fn merge_dict_weights( - dict: Vec>, - dict_weights: Vec, - words: &[Vec], - word_weights: &mut Vec>, - char_window_size: usize, - dict_word_wise: bool, - ) -> (Vec>, Vec) { - let mut word_map = HashMap::new(); - for (i, word) in words.iter().cloned().enumerate() { - word_map.insert(word, i); - } - let mut new_dict = vec![]; - if dict_word_wise { - let mut new_dict_weights = vec![]; - for (word, weight) in dict.into_iter().zip(dict_weights) { - let word_size = std::str::from_utf8(&word).unwrap().chars().count(); - match word_map.get(&word) { - Some(&idx) if char_window_size >= word_size => { - let start = char_window_size - word_size; - let end = start + word_size; - word_weights[idx][start] += weight.right; - for i in start + 1..end { - word_weights[idx][i] += weight.inner; - } - word_weights[idx][end] += weight.left; - } - _ => { - new_dict.push(word); - new_dict_weights.push(weight); - } - } - } - (new_dict, new_dict_weights) - } else { - for word in dict { - let word_size = std::str::from_utf8(&word).unwrap().chars().count(); - match word_map.get(&word) { - Some(&idx) if char_window_size >= word_size => { - let start = char_window_size - word_size; - let end = start + word_size; - let word_size_idx = std::cmp::min(word_size, dict_weights.len()) - 1; - let weight = &dict_weights[word_size_idx]; - word_weights[idx][start] += weight.right; - for i in start + 1..end { - word_weights[idx][i] += weight.inner; - } - word_weights[idx][end] += weight.left; - } - _ => new_dict.push(word), - } - } - (new_dict, dict_weights) - } + tag_names, + tag_bias, + }) } - fn merge_weights(words: &[Vec], weights: &[Vec]) -> Vec> { - let mut result = vec![]; - let word_ids = words - .iter() - .cloned() - .enumerate() - .map(|(i, w)| (w, i)) - .collect::, usize>>(); - for seq in words { - let mut new_weights: Option> = None; - for st in (0..seq.len()).rev() { - if let Some(&idx) = word_ids.get(&seq[st..]) { - if let Some(new_weights) = new_weights.as_mut() { - for (w_new, w) in new_weights.iter_mut().zip(&weights[idx]) { - *w_new += *w; - } - } else { - new_weights.replace(weights[idx].clone()); - } - } + fn predict_impl(&self, mut sentence: Sentence) -> Sentence { + let ys_size = sentence.boundaries.len() + self.padding + char_scorer::SIMD_SIZE - 1; + let mut ys = mem::take(&mut sentence.boundary_scores); + ys.clear(); + ys.resize(ys_size, self.bias); + match &self.char_scorer { + CharScorerWrapper::Boundary(char_scorer) => { + char_scorer.add_scores(&sentence, self.padding, &mut ys); } - result.push(new_weights.unwrap()); - } - result - } - - fn add_word_ngram_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { - let char_start = if start >= self.char_window_size { - start + 1 - self.char_window_size - } else { - 0 - }; - let text_start = sentence.char_to_str_pos[char_start]; - let char_end = std::cmp::min( - start + ys.len() + self.char_window_size, - sentence.char_to_str_pos.len() - 1, - ); - let text_end = sentence.char_to_str_pos[char_end]; - let text = &sentence.text[text_start..text_end]; - let 
padding = start - char_start + 1; - for m in self.word_pma.find_overlapping_no_suffix_iter(&text) { - let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start; - let offset = m_end as isize - self.char_window_size as isize - padding as isize; - let weights = &self.word_weights[m.pattern()]; - if offset >= 0 { - for (w, y) in weights.iter().zip(&mut ys[offset as usize..]) { - *y += w; - } - } else { - for (w, y) in weights[-offset as usize..].iter().zip(ys.iter_mut()) { - *y += w; - } + CharScorerWrapper::BoundaryAndTags(char_scorer) => { + let mut tag_ys = mem::take(&mut sentence.tag_scores); + tag_ys.init(sentence.chars.len(), self.tag_names.len()); + char_scorer.add_scores(&sentence, self.padding, &mut ys, &mut tag_ys); + sentence.tag_scores = tag_ys; } } - } - - fn add_dict_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { - let char_start = if start >= self.dict_window_size { - start + 1 - self.dict_window_size - } else { - 0 - }; - let text_start = sentence.char_to_str_pos[char_start]; - let char_end = std::cmp::min( - start + ys.len() + self.dict_window_size, - sentence.char_to_str_pos.len() - 1, - ); - let text_end = sentence.char_to_str_pos[char_end]; - let text = &sentence.text[text_start..text_end]; - let padding = start - char_start + 1; - for m in self.dict_pma.find_overlapping_iter(&text) { - let m_start = sentence.str_to_char_pos[m.start() + text_start] - char_start; - let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start; - let idx = if self.dict_word_wise { - m.pattern() - } else { - std::cmp::min(m_end - m_start, self.dict_weights.len()) - 1 - }; - let dict_weight = self.dict_weights[idx]; - if m_start >= padding && m_start < padding + ys.len() { - ys[m_start - padding] += dict_weight.right; - } - let range_start = std::cmp::max(0, m_start as isize - padding as isize + 1); - let range_end = std::cmp::min(m_end as isize - padding as isize, ys.len() as isize); - if range_start < range_end { - for y in &mut ys[range_start as usize..range_end as usize] { - *y += dict_weight.inner; - } - } - if m_end >= padding && m_end < ys.len() + padding { - ys[m_end - padding] += dict_weight.left; - } - } - } - - fn predict_partial_impl( - &self, - sentence: &Sentence, - range: Range, - ys: &mut [ScoreValue], - ) { - ys.fill(self.bias); - self.add_word_ngram_scores(sentence, range.start, ys); - self.type_scorer.add_scores(sentence, range.start, ys); - self.add_dict_scores(sentence, range.start, ys); - } - - /// Predicts word boundaries of the specified range of a sentence. - /// - /// # Arguments - /// - /// * `sentence` - A sentence. - /// * `range` - The range of the sentence. - /// - /// # Returns - /// - /// A sentence with predicted boundary information. - pub fn predict_partial(&self, mut sentence: Sentence, range: Range) -> Sentence { - let mut ys = vec![ScoreValue::default(); range.len()]; - self.predict_partial_impl(&sentence, range.clone(), &mut ys); - for (y, b) in ys.into_iter().zip(sentence.boundaries[range].iter_mut()) { - *b = if y >= ScoreValue::default() { - BoundaryType::WordBoundary - } else { - BoundaryType::NotWordBoundary - }; - } - sentence - } - - /// Predicts word boundaries of the specified range of a sentence. This function inserts - /// scores. - /// - /// # Arguments - /// - /// * `sentence` - A sentence. - /// * `range` - The range of the sentence. - /// - /// # Returns - /// - /// A sentence with predicted boundary information. 
- pub fn predict_partial_with_score( - &self, - mut sentence: Sentence, - range: Range, - ) -> Sentence { - let mut ys = vec![ScoreValue::default(); range.len()]; - self.predict_partial_impl(&sentence, range.clone(), &mut ys); - let mut scores = sentence - .boundary_scores - .take() - .unwrap_or_else(|| vec![0.; sentence.boundaries.len()]); - for (y, (b, s)) in ys.into_iter().zip( - sentence.boundaries[range.clone()] - .iter_mut() - .zip(scores[range].iter_mut()), - ) { - *b = if y >= ScoreValue::default() { + self.type_scorer + .add_scores(&sentence, &mut ys[self.padding..]); + for (&y, b) in ys[self.padding..] + .iter() + .zip(sentence.boundaries.iter_mut()) + { + *b = if y >= 0 { BoundaryType::WordBoundary } else { BoundaryType::NotWordBoundary }; - - #[cfg(feature = "model-quantize")] - let y = y as f64 * self.quantize_multiplier; - - *s = y; } - sentence.boundary_scores.replace(scores); + sentence.boundary_scores = ys; sentence } @@ -338,12 +121,9 @@ impl Predictor { /// /// A sentence with predicted boundary information. pub fn predict(&self, sentence: Sentence) -> Sentence { - let boundaries_size = sentence.boundaries.len(); - if boundaries_size == 0 { - sentence - } else { - self.predict_partial(sentence, 0..boundaries_size) - } + let mut sentence = self.predict_impl(sentence); + sentence.boundary_scores.clear(); + sentence } /// Predicts word boundaries. This function inserts scores. @@ -356,197 +136,113 @@ impl Predictor { /// /// A sentence with predicted boundary information. pub fn predict_with_score(&self, sentence: Sentence) -> Sentence { - let boundaries_size = sentence.boundaries.len(); - if boundaries_size == 0 { - sentence - } else { - self.predict_partial_with_score(sentence, 0..boundaries_size) - } + let mut sentence = self.predict_impl(sentence); + sentence.boundary_scores.rotate_left(self.padding); + sentence.boundary_scores.truncate(sentence.boundaries.len()); + sentence } - /// Sets the window size of words in the dictionary. - /// - /// # Arguments - /// - /// * `size` - The window size. - /// - /// # Returns - /// - /// A predictor with the specified window size. - pub fn dict_window_size(mut self, size: usize) -> Self { - self.dict_window_size = std::cmp::max(size, 1); - self + fn best_tag(&self, scores: &[i32]) -> Arc { + Arc::clone( + scores + .iter() + .zip(&self.tag_names) + .max_by_key(|(&x, _)| x) + .unwrap() + .1, + ) } - /// Creates a multithreading predictor. This function is the alias of - /// [`MultithreadPredictor::new()`]. - /// - /// # Arguments - /// - /// * `n_threads` - The number of threads. - /// * `chunk_size` - The chunk size of each thread. - /// - /// # Returns + /// Fills tags using calculated scores. /// - /// A multithread predictor. - #[cfg(feature = "multithreading")] - #[cfg_attr(docsrs, doc(cfg(feature = "multithreading")))] - pub fn multithreading(self, n_threads: usize, chunk_size: usize) -> MultithreadPredictor { - MultithreadPredictor::new(self, n_threads, chunk_size) - } -} - -/// Predictor for multithreading. -#[cfg(feature = "multithreading")] -#[cfg_attr(docsrs, doc(cfg(feature = "multithreading")))] -pub struct MultithreadPredictor { - task_tx: Sender<(Arc, Range, Vec)>, - result_rx: Receiver<(Vec, Range)>, - chunk_size: usize, - ys_pool: RefCell>>, - - #[cfg(feature = "model-quantize")] - quantize_multiplier: f64, -} - -#[cfg(feature = "multithreading")] -impl MultithreadPredictor { - /// Creates a multithreading predictor. 
+ /// Tags are predicted using token boundaries, so you have to apply boundary post-processors + /// before filling tags. /// /// # Arguments /// - /// * `predictor` - A normal predictor. - /// * `n_threads` - The number of threads. - /// * `chunk_size` - The chunk size of each thread. + /// * `sentence` - A sentence. /// /// # Returns /// - /// A multithread predictor. - pub fn new(predictor: Predictor, n_threads: usize, chunk_size: usize) -> Self { - let predictor = Arc::new(predictor); - - let (result_tx, result_rx) = crossbeam_channel::unbounded(); - let (task_tx, task_rx) = - crossbeam_channel::unbounded::<(Arc, Range, Vec)>(); - for _ in 0..n_threads { - let predictor = Arc::clone(&predictor); - let result_tx = result_tx.clone(); - let task_rx = task_rx.clone(); - thread::spawn(move || { - for (sentence, range, mut ys) in task_rx { - predictor.predict_partial_impl( - &sentence, - range.clone(), - &mut ys[..range.len()], - ); - std::mem::drop(sentence); - result_tx.send((ys, range)).unwrap(); - } - }); + /// A sentence with tag information. When the predictor is instantiated with + /// `predict_tag = false`, the sentence is returned without any modification. + pub fn fill_tags(&self, mut sentence: Sentence) -> Sentence { + if self.tag_names.is_empty() { + return sentence; } - - Self { - task_tx, - result_rx, - chunk_size, - ys_pool: RefCell::new(vec![]), - - #[cfg(feature = "model-quantize")] - quantize_multiplier: predictor.quantize_multiplier, + if sentence.tags.is_empty() { + sentence.tags.resize(sentence.chars().len(), None); } - } - - /// Predicts word boundaries. - /// - /// # Arguments - /// - /// * `sentence` - A sentence. - /// - /// # Returns - /// - /// A sentence with predicted boundary information. - pub fn predict(&self, sentence: Sentence) -> Sentence { - let sentence = Arc::new(sentence); - - let mut n_chunks = 0; - let mut ys_pool = self.ys_pool.borrow_mut(); - for start in (0..sentence.boundaries.len()).step_by(self.chunk_size) { - let ys = ys_pool - .pop() - .unwrap_or_else(|| vec![ScoreValue::default(); self.chunk_size]); - let sentence = Arc::clone(&sentence); - let end = std::cmp::min(start + self.chunk_size, sentence.boundaries.len()); - self.task_tx.send((sentence, start..end, ys)).unwrap(); - n_chunks += 1; + let n_tags = self.tag_names.len(); + let mut tag_score = self.tag_bias.clone(); + let mut left_scores_iter = sentence.tag_scores.left_scores.chunks(n_tags); + for (t, l) in tag_score.iter_mut().zip(left_scores_iter.next().unwrap()) { + *t += l; } - let mut boundaries = vec![BoundaryType::Unknown; sentence.boundaries.len()]; - for _ in 0..n_chunks { - let (ys, range) = self.result_rx.recv().unwrap(); - for (&y, b) in ys.iter().zip(&mut boundaries[range]) { - *b = if y >= ScoreValue::default() { - BoundaryType::WordBoundary - } else { - BoundaryType::NotWordBoundary - }; + let mut right_scores_iter = sentence.tag_scores.right_scores.chunks(n_tags); + let mut last_boundary_idx = 0; + for (i, ((((b, left_scores), right_scores), self_scores), tag)) in sentence + .boundaries + .iter() + .zip(left_scores_iter) + .zip(&mut right_scores_iter) + .zip(&sentence.tag_scores.self_scores) + .zip(&mut sentence.tags) + .enumerate() + { + if *b == BoundaryType::WordBoundary { + for (t, r) in tag_score.iter_mut().zip(right_scores) { + *t += *r; + } + if let Some(self_weights) = self_scores.as_ref() { + let diff = last_boundary_idx as i32 - i as i32 - 1; + for self_weight in self_weights.iter() { + match self_weight.start_rel_position.cmp(&diff) { + Ordering::Greater => 
continue, + Ordering::Equal => { + for (t, s) in tag_score.iter_mut().zip(&self_weight.weight) { + *t += *s; + } + } + Ordering::Less => (), + } + break; + } + } + tag.replace(self.best_tag(&tag_score)); + for (t, (l, b)) in tag_score + .iter_mut() + .zip(left_scores.iter().zip(&self.tag_bias)) + { + *t = *l + *b; + } + last_boundary_idx = i + 1; } - ys_pool.push(ys); } - - let mut sentence = Arc::try_unwrap(sentence).unwrap(); - sentence.boundaries = boundaries; - sentence - } - - /// Predicts word boundaries. This function inserts scores. - /// - /// # Arguments - /// - /// * `sentence` - A sentence. - /// - /// # Returns - /// - /// A sentence with predicted boundary information. - pub fn predict_with_score(&self, mut sentence: Sentence) -> Sentence { - let mut scores = sentence - .boundary_scores - .take() - .unwrap_or_else(|| vec![0.; sentence.boundaries.len()]); - let sentence = Arc::new(sentence); - let mut n_chunks = 0; - let mut ys_pool = self.ys_pool.borrow_mut(); - for start in (0..sentence.boundaries.len()).step_by(self.chunk_size) { - let ys = ys_pool - .pop() - .unwrap_or_else(|| vec![ScoreValue::default(); self.chunk_size]); - let sentence = Arc::clone(&sentence); - let end = std::cmp::min(start + self.chunk_size, sentence.boundaries.len()); - self.task_tx.send((sentence, start..end, ys)).unwrap(); - n_chunks += 1; + for (t, r) in tag_score.iter_mut().zip(right_scores_iter.next().unwrap()) { + *t += r; } - let mut boundaries = vec![BoundaryType::Unknown; sentence.boundaries.len()]; - for _ in 0..n_chunks { - let (ys, range) = self.result_rx.recv().unwrap(); - for (&y, (b, s)) in ys - .iter() - .zip(boundaries[range.clone()].iter_mut().zip(&mut scores[range])) - { - *b = if y >= ScoreValue::default() { - BoundaryType::WordBoundary - } else { - BoundaryType::NotWordBoundary - }; - - #[cfg(feature = "model-quantize")] - let y = y as f64 * self.quantize_multiplier; - - *s = y; + if let Some(self_weights) = sentence.tag_scores.self_scores.last().unwrap().as_ref() { + let diff = last_boundary_idx as i32 - sentence.chars.len() as i32; + for self_weight in self_weights.iter() { + match self_weight.start_rel_position.cmp(&diff) { + Ordering::Greater => continue, + Ordering::Equal => { + for (t, s) in tag_score.iter_mut().zip(&self_weight.weight) { + *t += *s; + } + } + Ordering::Less => (), + } + break; } - ys_pool.push(ys); } + sentence + .tags + .last_mut() + .unwrap() + .replace(self.best_tag(&tag_score)); - let mut sentence = Arc::try_unwrap(sentence).unwrap(); - sentence.boundaries = boundaries; - sentence.boundary_scores.replace(scores); sentence } } @@ -555,6 +251,11 @@ impl MultithreadPredictor { mod tests { use super::*; + use crate::dict_model::{DictModel, DictWeight, WordWeightRecord}; + use crate::ngram_model::{NgramData, NgramModel}; + use crate::sentence::Token; + use crate::tag_model::{TagClassInfo, TagModel}; + /// Input: 我 ら は 全 世 界 の 国 民 /// bias: -200 .. .. .. .. .. .. .. 
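For orientation, a minimal usage sketch of the predictor API exercised by the tests below (not part of the diff): the model path is hypothetical, loading via `Model::read` is assumed, and the new second argument to `Predictor::new` enables tag prediction.

```rust
use std::fs::File;
use std::io::BufReader;

use vaporetto::{Model, Predictor, Sentence};

fn main() {
    // Hypothetical model path; loading via `Model::read` is assumed here.
    let mut f = BufReader::new(File::open("path/to/model").unwrap());
    let model = Model::read(&mut f).unwrap();

    // The second argument enables tag prediction.
    let predictor = Predictor::new(model, true).unwrap();

    let s = predictor.predict(Sentence::from_raw("人と人をつなぐ人").unwrap());
    // Boundary post-processors (e.g. from vaporetto_rules) could be applied here,
    // before filling tags, as the doc comment above requires.
    let s = predictor.fill_tags(s);

    for token in s.to_tokenized_vec().unwrap() {
        println!("{}\t{}", token.surface, token.tag.unwrap_or("*"));
    }
}
```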
/// words: @@ -583,84 +284,81 @@ mod tests { /// 世: 40 42 fn generate_model_1() -> Model { Model { - words: vec![ - "我ら".as_bytes().to_vec(), - "全世界".as_bytes().to_vec(), - "国民".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "界".as_bytes().to_vec(), - ], - types: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], - dict: vec![ - "全世界".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "世".as_bytes().to_vec(), - ], - #[cfg(not(feature = "model-quantize"))] - word_weights: vec![ - vec![0.5, 1.0, 1.5, 2.0, 2.5], - vec![3.0, 3.5, 4.0, 4.5], - vec![5.0, 5.5, 6.0, 6.5, 7.0], - vec![7.5, 8.0, 8.5, 9.0, 9.5], - vec![10.0, 10.5, 11.0, 11.5, 12.0, 12.5], - ], - #[cfg(feature = "model-quantize")] - word_weights: vec![ - vec![1, 2, 3, 4, 5], - vec![6, 7, 8, 9], - vec![10, 11, 12, 13, 14], - vec![15, 16, 17, 18, 19], - vec![20, 21, 22, 23, 24, 25], - ], - #[cfg(not(feature = "model-quantize"))] - type_weights: vec![ - vec![13.0, 13.5, 14.0, 14.5], - vec![15.0, 15.5, 16.0, 16.5], - vec![17.0, 17.5, 18.0], - vec![18.5, 19.0, 19.5], - ], - #[cfg(feature = "model-quantize")] - type_weights: vec![ - vec![26, 27, 28, 29], - vec![30, 31, 32, 33], - vec![34, 35, 36], - vec![37, 38, 39], - ], - #[cfg(not(feature = "model-quantize"))] - dict_weights: vec![ - DictWeight { - right: 20.0, - inner: 20.5, - left: 21.0, + char_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: "我ら".to_string(), + weights: vec![1, 2, 3, 4, 5], }, - DictWeight { - right: 21.5, - inner: 22.0, - left: 22.5, + NgramData { + ngram: "全世界".to_string(), + weights: vec![6, 7, 8, 9], }, - ], - #[cfg(feature = "model-quantize")] - dict_weights: vec![ - DictWeight { - right: 40, - inner: 41, - left: 42, + NgramData { + ngram: "国民".to_string(), + weights: vec![10, 11, 12, 13, 14], }, - DictWeight { - right: 43, - inner: 44, - left: 45, + NgramData { + ngram: "世界".to_string(), + weights: vec![15, 16, 17, 18, 19], }, - ], - #[cfg(feature = "model-quantize")] - quantize_multiplier: 0.5, - dict_word_wise: false, - #[cfg(not(feature = "model-quantize"))] - bias: -100.0, - #[cfg(feature = "model-quantize")] + NgramData { + ngram: "界".to_string(), + weights: vec![20, 21, 22, 23, 24, 25], + }, + ]), + type_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: b"H".to_vec(), + weights: vec![26, 27, 28, 29], + }, + NgramData { + ngram: b"K".to_vec(), + weights: vec![30, 31, 32, 33], + }, + NgramData { + ngram: b"KH".to_vec(), + weights: vec![34, 35, 36], + }, + NgramData { + ngram: b"HK".to_vec(), + weights: vec![37, 38, 39], + }, + ]), + dict_model: DictModel { + dict: vec![ + WordWeightRecord { + word: "全世界".to_string(), + weights: DictWeight { + right: 43, + inside: 44, + left: 45, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "世界".to_string(), + weights: DictWeight { + right: 43, + inside: 44, + left: 45, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "世".to_string(), + weights: DictWeight { + right: 40, + inside: 41, + left: 42, + }, + comment: "".to_string(), + }, + ], + }, bias: -200, char_window_size: 3, type_window_size: 2, + tag_model: TagModel::default(), } } @@ -692,94 +390,81 @@ mod tests { /// 世: 38 40 fn generate_model_2() -> Model { Model { - words: vec![ - "我ら".as_bytes().to_vec(), - "全世界".as_bytes().to_vec(), - "国民".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "界".as_bytes().to_vec(), - ], - types: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], - dict: vec![ - "全世界".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "世".as_bytes().to_vec(), - 
], - #[cfg(not(feature = "model-quantize"))] - word_weights: vec![ - vec![0.25, 0.5, 0.75], - vec![1.0, 1.25], - vec![1.5, 1.75, 2.0], - vec![2.25, 2.5, 2.75], - vec![3.0, 3.25, 3.5, 3.75], - ], - #[cfg(feature = "model-quantize")] - word_weights: vec![ - vec![1, 2, 3], - vec![4, 5], - vec![6, 7, 8], - vec![9, 10, 11], - vec![12, 13, 14, 15], - ], - #[cfg(not(feature = "model-quantize"))] - type_weights: vec![ - vec![4.0, 4.25, 4.5, 4.75, 5.0, 5.25], - vec![5.5, 5.75, 6.0, 6.25, 6.5, 6.75], - vec![7.0, 7.25, 7.5, 7.75, 8.0], - vec![8.25, 8.5, 8.75, 9.0, 9.25], - ], - #[cfg(feature = "model-quantize")] - type_weights: vec![ - vec![16, 17, 18, 19, 20, 21], - vec![22, 23, 24, 25, 26, 27], - vec![28, 29, 30, 31, 32], - vec![33, 34, 35, 36, 37], - ], - #[cfg(not(feature = "model-quantize"))] - dict_weights: vec![ - DictWeight { - right: 9.5, - inner: 9.75, - left: 10.0, + char_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: "我ら".to_string(), + weights: vec![1, 2, 3], }, - DictWeight { - right: 10.25, - inner: 10.5, - left: 10.75, + NgramData { + ngram: "全世界".to_string(), + weights: vec![4, 5], }, - DictWeight { - right: 11.0, - inner: 11.25, - left: 11.5, + NgramData { + ngram: "国民".to_string(), + weights: vec![6, 7, 8], }, - ], - #[cfg(feature = "model-quantize")] - dict_weights: vec![ - DictWeight { - right: 38, - inner: 39, - left: 40, + NgramData { + ngram: "世界".to_string(), + weights: vec![9, 10, 11], }, - DictWeight { - right: 41, - inner: 42, - left: 43, + NgramData { + ngram: "界".to_string(), + weights: vec![12, 13, 14, 15], }, - DictWeight { - right: 44, - inner: 45, - left: 46, + ]), + type_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: b"H".to_vec(), + weights: vec![16, 17, 18, 19, 20, 21], }, - ], - #[cfg(feature = "model-quantize")] - quantize_multiplier: 0.25, - dict_word_wise: false, - #[cfg(not(feature = "model-quantize"))] - bias: -71.25, - #[cfg(feature = "model-quantize")] + NgramData { + ngram: b"K".to_vec(), + weights: vec![22, 23, 24, 25, 26, 27], + }, + NgramData { + ngram: b"KH".to_vec(), + weights: vec![28, 29, 30, 31, 32], + }, + NgramData { + ngram: b"HK".to_vec(), + weights: vec![33, 34, 35, 36, 37], + }, + ]), + dict_model: DictModel { + dict: vec![ + WordWeightRecord { + word: "全世界".to_string(), + weights: DictWeight { + right: 44, + inside: 45, + left: 46, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "世界".to_string(), + weights: DictWeight { + right: 41, + inside: 42, + left: 43, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "世".to_string(), + weights: DictWeight { + right: 38, + inside: 39, + left: 40, + }, + comment: "".to_string(), + }, + ], + }, bias: -285, char_window_size: 2, type_window_size: 3, + tag_model: TagModel::default(), } } @@ -811,101 +496,343 @@ mod tests { /// 世: 44 46 fn generate_model_3() -> Model { Model { - words: vec![ - "我ら".as_bytes().to_vec(), - "全世界".as_bytes().to_vec(), - "国民".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "界".as_bytes().to_vec(), - ], - types: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], - dict: vec![ - "国民".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "世".as_bytes().to_vec(), - ], - #[cfg(not(feature = "model-quantize"))] - word_weights: vec![ - vec![0.25, 0.5, 0.75], - vec![1.0, 1.25], - vec![1.5, 1.75, 2.0], - vec![2.25, 2.5, 2.75], - vec![3.0, 3.25, 3.5, 3.75], - ], - #[cfg(feature = "model-quantize")] - word_weights: vec![ - vec![1, 2, 3], - vec![4, 5], - vec![6, 7, 8], - vec![9, 10, 11], - vec![12, 13, 14, 15], - ], - 
#[cfg(not(feature = "model-quantize"))] - type_weights: vec![ - vec![4.0, 4.25, 4.5, 4.75, 5.0, 5.25], - vec![5.5, 5.75, 6.0, 6.25, 6.5, 6.75], - vec![7.0, 7.25, 7.5, 7.75, 8.0], - vec![8.25, 8.5, 8.75, 9.0, 9.25], - ], - #[cfg(feature = "model-quantize")] - type_weights: vec![ - vec![16, 17, 18, 19, 20, 21], - vec![22, 23, 24, 25, 26, 27], - vec![28, 29, 30, 31, 32], - vec![33, 34, 35, 36, 37], - ], - #[cfg(not(feature = "model-quantize"))] - dict_weights: vec![ - DictWeight { - right: 9.5, - inner: 9.75, - left: 11.0, + char_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: "我ら".to_string(), + weights: vec![1, 2, 3], }, - DictWeight { - right: 10.25, - inner: 10.5, - left: 10.75, + NgramData { + ngram: "全世界".to_string(), + weights: vec![4, 5], }, - DictWeight { - right: 11.0, - inner: 11.25, - left: 11.5, + NgramData { + ngram: "国民".to_string(), + weights: vec![6, 7, 8], }, - ], - #[cfg(feature = "model-quantize")] - dict_weights: vec![ - DictWeight { - right: 38, - inner: 39, - left: 40, + NgramData { + ngram: "世界".to_string(), + weights: vec![9, 10, 11], }, - DictWeight { - right: 41, - inner: 42, - left: 43, + NgramData { + ngram: "界".to_string(), + weights: vec![12, 13, 14, 15], }, - DictWeight { - right: 44, - inner: 45, - left: 46, + ]), + type_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: b"H".to_vec(), + weights: vec![16, 17, 18, 19, 20, 21], }, - ], - #[cfg(feature = "model-quantize")] - quantize_multiplier: 0.25, - dict_word_wise: true, - #[cfg(not(feature = "model-quantize"))] - bias: -71.25, - #[cfg(feature = "model-quantize")] + NgramData { + ngram: b"K".to_vec(), + weights: vec![22, 23, 24, 25, 26, 27], + }, + NgramData { + ngram: b"KH".to_vec(), + weights: vec![28, 29, 30, 31, 32], + }, + NgramData { + ngram: b"HK".to_vec(), + weights: vec![33, 34, 35, 36, 37], + }, + ]), + dict_model: DictModel { + dict: vec![ + WordWeightRecord { + word: "国民".to_string(), + weights: DictWeight { + right: 38, + inside: 39, + left: 40, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "世界".to_string(), + weights: DictWeight { + right: 41, + inside: 42, + left: 43, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "世".to_string(), + weights: DictWeight { + right: 44, + inside: 45, + left: 46, + }, + comment: "".to_string(), + }, + ], + }, bias: -285, char_window_size: 2, type_window_size: 3, + tag_model: TagModel::default(), + } + } + + /// Input: 我 ら は 全 世 界 の 国 民 + /// bias: -200 .. .. .. .. .. .. .. 
+ /// chars: + /// 我ら: 3 4 5 + /// 全世界: 6 7 8 9 + /// 国民: 10 11 12 + /// 世界: 15 16 17 18 19 + /// 界: 20 21 22 23 24 25 + /// types: + /// H: 27 28 29 + /// 26 27 28 29 + /// 26 27 28 29 + /// K: 32 33 + /// 30 31 32 33 + /// 30 31 32 33 + /// 30 31 32 33 + /// 30 31 32 + /// 30 31 + /// KH: 35 36 + /// 34 35 36 + /// HK: 37 38 39 + /// 37 38 39 + /// dict: + /// 全世界: 43 44 44 45 + /// 世界: 43 44 45 + /// 世: 40 42 + /// 世界の国民: 43 44 44 44 44 + /// は全世界: 43 44 44 44 45 + /// + /// + /// は全世界: 43 44 44 44 45 + /// 15 16 17 18 19 + /// 20 21 22 23 24 25 + /// 6 7 8 9 + fn generate_model_4() -> Model { + Model { + char_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: "我ら".to_string(), + weights: vec![1, 2, 3, 4, 5], + }, + NgramData { + ngram: "全世界".to_string(), + weights: vec![6, 7, 8, 9], + }, + NgramData { + ngram: "国民".to_string(), + weights: vec![10, 11, 12, 13, 14], + }, + NgramData { + ngram: "世界".to_string(), + weights: vec![15, 16, 17, 18, 19], + }, + NgramData { + ngram: "界".to_string(), + weights: vec![20, 21, 22, 23, 24, 25], + }, + ]), + type_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: b"H".to_vec(), + weights: vec![26, 27, 28, 29], + }, + NgramData { + ngram: b"K".to_vec(), + weights: vec![30, 31, 32, 33], + }, + NgramData { + ngram: b"KH".to_vec(), + weights: vec![34, 35, 36], + }, + NgramData { + ngram: b"HK".to_vec(), + weights: vec![37, 38, 39], + }, + ]), + dict_model: DictModel { + dict: vec![ + WordWeightRecord { + word: "全世界".to_string(), + weights: DictWeight { + right: 43, + inside: 44, + left: 45, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "世界".to_string(), + weights: DictWeight { + right: 43, + inside: 44, + left: 45, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "世".to_string(), + weights: DictWeight { + right: 40, + inside: 41, + left: 42, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "世界の国民".to_string(), + weights: DictWeight { + right: 43, + inside: 44, + left: 45, + }, + comment: "".to_string(), + }, + WordWeightRecord { + word: "は全世界".to_string(), + weights: DictWeight { + right: 43, + inside: 44, + left: 45, + }, + comment: "".to_string(), + }, + ], + }, + bias: -200, + char_window_size: 3, + type_window_size: 2, + tag_model: TagModel::default(), + } + } + + /// Input: 人 と 人 を つ な ぐ 人 + /// left: + /// \0人: 1 4 + /// 2 5 + /// 3 6 + /// 人: 7 10 7 10 + /// 8 11 8 11 + /// 9 12 9 12 + /// つなぐ: 13 16 19 + /// 14 17 20 + /// 15 18 21 + /// 人\0: 22 + /// 23 + /// 24 + /// + /// sum: 1 11 10 7 10 13 16 41 + /// 2 13 11 8 11 14 17 43 + /// 3 15 12 9 12 15 18 45 + /// + /// right: + /// \0人と: 28 + /// 29 + /// 30 + /// 人を: 31 34 37 + /// 32 35 38 + /// 33 36 39 + /// を: 40 43 + /// 41 44 + /// 42 45 + /// 人\0: 46 49 + /// 47 50 + /// 48 51 + /// + /// sum: 28 71 77 37 0 0 46 49 + /// 29 73 79 38 0 0 47 50 + /// 30 75 81 39 0 0 48 51 + fn generate_model_5() -> Model { + Model { + char_ngram_model: NgramModel::new(vec![NgramData { + ngram: "xxxx".to_string(), + weights: vec![0], + }]), + type_ngram_model: NgramModel::new(vec![NgramData { + ngram: b"RRRR".to_vec(), + weights: vec![0], + }]), + dict_model: DictModel { dict: vec![] }, + bias: 0, + char_window_size: 2, + type_window_size: 2, + tag_model: TagModel { + class_info: vec![ + TagClassInfo { + name: "名詞".to_string(), + bias: 5, + }, + TagClassInfo { + name: "動詞".to_string(), + bias: 3, + }, + TagClassInfo { + name: "助詞".to_string(), + bias: 1, + }, + ], + left_char_model: NgramModel::new(vec![ + NgramData { + ngram: 
"\0人".to_string(), + weights: vec![1, 2, 3, 4, 5, 6], + }, + NgramData { + ngram: "人".to_string(), + weights: vec![7, 8, 9, 10, 11, 12], + }, + NgramData { + ngram: "つなぐ".to_string(), + weights: vec![13, 14, 15, 16, 17, 18, 19, 20, 21], + }, + NgramData { + ngram: "ぐ人\0".to_string(), + weights: vec![22, 23, 24], + }, + ]), + right_char_model: NgramModel::new(vec![ + NgramData { + ngram: "\0人と".to_string(), + weights: vec![25, 26, 27, 28, 29, 30], + }, + NgramData { + ngram: "人を".to_string(), + weights: vec![31, 32, 33, 34, 35, 36, 37, 38, 39], + }, + NgramData { + ngram: "を".to_string(), + weights: vec![40, 41, 42, 43, 44, 45], + }, + NgramData { + ngram: "人\0".to_string(), + weights: vec![46, 47, 48, 49, 50, 51], + }, + ]), + self_char_model: NgramModel::new(vec![ + NgramData { + ngram: "人".to_string(), + weights: vec![2, -1, -1], + }, + NgramData { + ngram: "と".to_string(), + weights: vec![0, 0, 0], + }, + NgramData { + ngram: "つなぐ".to_string(), + weights: vec![0, 1, 0], + }, + NgramData { + ngram: "を".to_string(), + weights: vec![0, 0, 0], + }, + ]), + }, } } #[test] fn test_predict_1() { let model = generate_model_1(); - let p = Predictor::new(model); + let p = Predictor::new(model, false).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict(s); assert_eq!( @@ -926,7 +853,7 @@ mod tests { #[test] fn test_predict_2() { let model = generate_model_2(); - let p = Predictor::new(model); + let p = Predictor::new(model, false).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict(s); assert_eq!( @@ -947,7 +874,7 @@ mod tests { #[test] fn test_predict_3() { let model = generate_model_3(); - let p = Predictor::new(model); + let p = Predictor::new(model, false).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict(s); assert_eq!( @@ -965,12 +892,34 @@ mod tests { ); } + #[test] + fn test_predict_4() { + let model = generate_model_4(); + let p = Predictor::new(model, false).unwrap(); + let s = Sentence::from_raw("我らは全世界の国民").unwrap(); + let s = p.predict(s); + assert_eq!( + &[ + BoundaryType::NotWordBoundary, + BoundaryType::WordBoundary, + BoundaryType::WordBoundary, + BoundaryType::WordBoundary, + BoundaryType::WordBoundary, + BoundaryType::WordBoundary, + BoundaryType::WordBoundary, + BoundaryType::WordBoundary, + ], + s.boundaries(), + ); + } + #[test] fn test_predict_with_score_1() { let model = generate_model_1(); - let p = Predictor::new(model); + let p = Predictor::new(model, false).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict_with_score(s); + assert_eq!(&[-77, -5, 45, 132, 133, 144, 50, -32], s.boundary_scores(),); assert_eq!( &[ BoundaryType::NotWordBoundary, @@ -984,18 +933,18 @@ mod tests { ], s.boundaries(), ); - assert_eq!( - &[-38.5, -2.5, 22.5, 66.0, 66.5, 72.0, 25.0, -16.0], - s.boundary_scores().unwrap(), - ); } #[test] fn test_predict_with_score_2() { let model = generate_model_2(); - let p = Predictor::new(model); + let p = Predictor::new(model, false).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict_with_score(s); + assert_eq!( + &[-138, -109, -39, 57, 104, 34, -79, -114], + s.boundary_scores(), + ); assert_eq!( &[ BoundaryType::NotWordBoundary, @@ -1009,18 +958,18 @@ mod tests { ], s.boundaries(), ); - assert_eq!( - &[-34.5, -27.25, -9.75, 14.25, 26.0, 8.5, -19.75, -28.5], - s.boundary_scores().unwrap(), - ); } #[test] fn test_predict_with_score_3() { let model = generate_model_3(); - let p = Predictor::new(model); + let p = 
Predictor::new(model, false).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict_with_score(s); + assert_eq!( + &[-138, -109, -83, 18, 65, -12, -41, -75], + s.boundary_scores(), + ); assert_eq!( &[ BoundaryType::NotWordBoundary, @@ -1034,147 +983,90 @@ mod tests { ], s.boundaries(), ); - assert_eq!( - &[-34.5, -27.25, -20.75, 4.5, 16.25, -3.0, -10.25, -18.75], - s.boundary_scores().unwrap(), - ); } #[test] - fn test_predict_partial_1() { - let model = generate_model_1(); - let p = Predictor::new(model); + fn test_predict_with_score_4() { + let model = generate_model_4(); + let p = Predictor::new(model, false).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial(s, 1..5); + let s = p.predict_with_score(s); + assert_eq!(&[-77, 38, 89, 219, 221, 233, 94, 12], s.boundary_scores(),); assert_eq!( &[ - BoundaryType::Unknown, BoundaryType::NotWordBoundary, BoundaryType::WordBoundary, BoundaryType::WordBoundary, BoundaryType::WordBoundary, - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::Unknown, - ], - s.boundaries(), - ); - } - - #[test] - fn test_predict_partial_2() { - let model = generate_model_2(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial(s, 2..7); - assert_eq!( - &[ - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, BoundaryType::WordBoundary, BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::Unknown, - ], - s.boundaries(), - ); - } - - #[test] - fn test_predict_partial_3() { - let model = generate_model_3(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial(s, 2..6); - assert_eq!( - &[ - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, BoundaryType::WordBoundary, BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::Unknown, - BoundaryType::Unknown, ], s.boundaries(), ); } #[test] - fn test_predict_partial_with_score_1() { - let model = generate_model_1(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial_with_score(s, 1..5); + fn test_predict_with_score_5() { + let model = generate_model_5(); + let p = Predictor::new(model, true).unwrap(); + let s = Sentence::from_raw("人と人をつなぐ人").unwrap(); + let mut s = p.predict(s); assert_eq!( &[ - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::Unknown, + 1, 2, 3, 11, 13, 15, 10, 11, 12, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 41, + 43, 45 ], - s.boundaries(), + s.tag_scores.left_scores.as_slice() ); - assert_eq!( - &[0.0, -2.5, 22.5, 66.0, 66.5, 0.0, 0.0, 0.0], - s.boundary_scores().unwrap(), - ); - } - - #[test] - fn test_predict_partial_with_score_2() { - let model = generate_model_2(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial_with_score(s, 2..7); assert_eq!( &[ - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::Unknown, + 28, 29, 30, 71, 73, 75, 77, 79, 81, 37, 38, 39, 0, 0, 0, 0, 0, 0, 46, 47, 48, 49, + 50, 51 ], - s.boundaries(), + 
s.tag_scores.right_scores.as_slice() ); - assert_eq!( - &[0.0, 0.0, -9.75, 14.25, 26.0, 8.5, -19.75, 0.0], - s.boundary_scores().unwrap(), - ); - } - #[test] - fn test_predict_partial_with_score_3() { - let model = generate_model_3(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial_with_score(s, 2..6); + s.boundaries_mut().copy_from_slice(&[ + BoundaryType::WordBoundary, + BoundaryType::WordBoundary, + BoundaryType::WordBoundary, + BoundaryType::WordBoundary, + BoundaryType::NotWordBoundary, + BoundaryType::NotWordBoundary, + BoundaryType::WordBoundary, + ]); + let s = p.fill_tags(s); + assert_eq!( - &[ - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::Unknown, - BoundaryType::Unknown, + vec![ + Token { + surface: "人", + tag: Some("名詞") + }, + Token { + surface: "と", + tag: Some("助詞") + }, + Token { + surface: "人", + tag: Some("名詞") + }, + Token { + surface: "を", + tag: Some("助詞") + }, + Token { + surface: "つなぐ", + tag: Some("動詞") + }, + Token { + surface: "人", + tag: Some("名詞") + } ], - s.boundaries(), - ); - assert_eq!( - &[0.0, 0.0, -20.75, 4.5, 16.25, -3.0, 0.0, 0.0], - s.boundary_scores().unwrap(), + s.to_tokenized_vec().unwrap(), ); } } diff --git a/vaporetto/src/sentence.rs b/vaporetto/src/sentence.rs index 9f7804b4..a3c0cf41 100644 --- a/vaporetto/src/sentence.rs +++ b/vaporetto/src/sentence.rs @@ -1,4 +1,6 @@ -use anyhow::{anyhow, Result}; +use std::sync::Arc; + +use crate::errors::{Result, VaporettoError}; /// Character type. #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] @@ -76,43 +78,403 @@ pub enum BoundaryType { Unknown = 2, } +/// Token information. +#[derive(Debug, PartialEq, Clone)] +pub struct Token<'a> { + /// A surface of this token. + pub surface: &'a str, + + /// A part-of-speech tag of this token. + pub tag: Option<&'a str>, +} + +/// Weight array with the corresponding range. +/// +/// This data is placed on the end of each range. +#[derive(Debug, PartialEq, Clone)] +pub struct TagRangeScore { + /// The relative position of the start position from the end position. + pub start_rel_position: i32, + + /// Weight array. + pub weight: Vec, +} + +impl TagRangeScore { + #[allow(dead_code)] + pub fn new(start_rel_position: i32, weight: Vec) -> Self { + Self { + start_rel_position, + weight, + } + } +} + +pub type TagRangeScores = Arc>; + +#[derive(Debug, PartialEq, Clone, Default)] +pub struct TagScores { + pub left_scores: Vec, + pub right_scores: Vec, + pub self_scores: Vec>, +} + +impl TagScores { + /// Clears scores. + pub fn clear(&mut self) { + self.left_scores.clear(); + self.right_scores.clear(); + self.self_scores.clear(); + } + + /// Initializes score arrays. + /// + /// # Arguments + /// + /// * `n_chars` - Length of characters in code points. + /// * `n_tags` - The number of tags. + #[allow(dead_code)] + pub fn init(&mut self, n_chars: usize, n_tags: usize) { + self.clear(); + self.left_scores.resize(n_chars * n_tags, 0); + self.right_scores.resize(n_chars * n_tags, 0); + self.self_scores.resize(n_chars, None); + } +} + /// Sentence with boundary annotations. 
#[derive(Debug, PartialEq, Clone)] pub struct Sentence { pub(crate) text: String, + pub(crate) chars: Vec, pub(crate) str_to_char_pos: Vec, pub(crate) char_to_str_pos: Vec, pub(crate) char_type: Vec, pub(crate) boundaries: Vec, - pub(crate) boundary_scores: Option>, + pub(crate) boundary_scores: Vec, + pub(crate) tag_scores: TagScores, + pub(crate) tags: Vec>>, } impl Sentence { - fn common_info(chars: &[char]) -> (Vec, Vec, Vec) { - let mut char_to_str_pos = Vec::with_capacity(chars.len() + 1); - let mut char_type = Vec::with_capacity(chars.len()); + fn internal_new( + text: String, + chars: Vec, + boundaries: Vec, + tags: Vec>>, + ) -> Self { + let mut s = Self { + text, + chars, + str_to_char_pos: vec![], + char_to_str_pos: vec![], + char_type: vec![], + boundaries, + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags, + }; + s.update_common_info(); + s + } + + fn clear(&mut self) { + self.text.clear(); + self.text.push(' '); + self.chars.clear(); + self.chars.push(' '); + self.str_to_char_pos.clear(); + self.str_to_char_pos.push(0); + self.str_to_char_pos.push(1); + self.char_to_str_pos.clear(); + self.char_to_str_pos.push(0); + self.char_to_str_pos.push(1); + self.char_type.clear(); + self.char_type.push(CharacterType::Other as u8); + self.boundaries.clear(); + self.boundary_scores.clear(); + self.tag_scores.clear(); + self.tags.clear(); + self.tags.push(None); + } + + fn parse_raw_text( + raw_text: &str, + chars: &mut Vec, + boundaries: &mut Vec, + tags: &mut Vec>>, + ) -> Result<()> { + if raw_text.is_empty() { + return Err(VaporettoError::invalid_argument( + "raw_text", + "must contain at least one character", + )); + } + + chars.clear(); + + for c in raw_text.chars() { + if c == '\0' { + return Err(VaporettoError::invalid_argument( + "raw_text", + "must not contain NULL", + )); + } + chars.push(c); + } + boundaries.clear(); + boundaries.resize(chars.len() - 1, BoundaryType::Unknown); + tags.clear(); + tags.resize(chars.len(), None); + + Ok(()) + } + + fn parse_tokenized_text( + tokenized_text: &str, + text: &mut String, + chars: &mut Vec, + boundaries: &mut Vec, + tags: &mut Vec>>, + ) -> Result<()> { + if tokenized_text.is_empty() { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "must contain at least one character", + )); + } + + text.clear(); + text.reserve(tokenized_text.len()); + chars.clear(); + boundaries.clear(); + tags.clear(); + + let mut tag_str_tmp = None; + let mut tag_str = None; + let mut prev_boundary = false; + let mut escape = false; + for c in tokenized_text.chars() { + match (escape, c) { + // escape a following character + (false, '\\') => { + escape = true; + } + // token boundary + (false, ' ') => { + if chars.is_empty() { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "must not start with a whitespace", + )); + } + if prev_boundary { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "must not contain consecutive whitespaces", + )); + } + prev_boundary = true; + tag_str = tag_str_tmp.take(); + } + // POS tag + (false, '/') => { + if chars.is_empty() { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "must not start with a slash", + )); + } + if prev_boundary { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "a slash must follow a character", + )); + } + tag_str_tmp.replace("".to_string()); + } + // escaped character or other character + (_, _) => { + if let Some(tag) = tag_str_tmp.as_mut() { + tag.push(c); + continue; + } + if 
!chars.is_empty() { + boundaries.push(if prev_boundary { + BoundaryType::WordBoundary + } else { + BoundaryType::NotWordBoundary + }); + tags.push(tag_str.take().map(Arc::new)); + } + if c == '\0' { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "must not contain NULL", + )); + } + prev_boundary = false; + escape = false; + text.push(c); + chars.push(c); + } + }; + } + + if prev_boundary { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "must not end with a whitespace", + )); + } + tags.push(tag_str_tmp.take().map(Arc::new)); + + Ok(()) + } + + fn parse_partial_annotation( + labeled_text: &str, + text: &mut String, + chars: &mut Vec, + boundaries: &mut Vec, + tags: &mut Vec>>, + ) -> Result<()> { + if labeled_text.is_empty() { + return Err(VaporettoError::invalid_argument( + "labeled_text", + "must contain at least one character", + )); + } + + let labeled_chars: Vec = labeled_text.chars().collect(); + + text.clear(); + chars.clear(); + boundaries.clear(); + tags.clear(); + + let mut tag_str = None; + let mut is_char = true; + let mut fixed_token = true; + for &c in &labeled_chars { + if is_char { + if c == '\0' { + return Err(VaporettoError::invalid_argument( + "labeled_text", + "must not contain NULL", + )); + } + text.push(c); + chars.push(c); + is_char = false; + continue; + } + match c { + // unannotated boundary + ' ' => { + if tag_str.is_some() { + return Err(VaporettoError::invalid_argument( + "labeled_text", + "POS tag must be annotated to a token".to_string(), + )); + } + tags.push(None); + boundaries.push(BoundaryType::Unknown); + is_char = true; + fixed_token = false; + } + // token boundary + '|' => { + if !fixed_token && tag_str.is_some() { + return Err(VaporettoError::invalid_argument( + "labeled_text", + "POS tag must be annotated to a token".to_string(), + )); + } + tags.push(tag_str.take().map(Arc::new)); + boundaries.push(BoundaryType::WordBoundary); + is_char = true; + fixed_token = true; + } + // not token boundary + '-' => { + if tag_str.is_some() { + return Err(VaporettoError::invalid_argument( + "labeled_text", + "POS tag must be annotated to a token".to_string(), + )); + } + tags.push(None); + boundaries.push(BoundaryType::NotWordBoundary); + is_char = true; + } + // POS tag + '/' => { + tag_str.replace("".to_string()); + } + _ => { + if let Some(tag) = tag_str.as_mut() { + tag.push(c); + } else { + return Err(VaporettoError::invalid_argument( + "labeled_text", + format!("contains an invalid boundary character: '{}'", c), + )); + } + } + } + } + tags.push(tag_str.take().map(Arc::new)); + if chars.len() != boundaries.len() + 1 { + return Err(VaporettoError::invalid_argument( + "labeled_text", + "invalid annotation".to_string(), + )); + } + + Ok(()) + } + + /// Updates char_to_str_pos, str_to_char_pos, and char_type. + /// + /// This function allocates: + /// + /// * char_to_str_pos: chars.len() + 1 + /// * str_to_char_pos: text.len() + 1 + /// * char_type: chars.len() + /// + /// If these variables already have sufficient spaces, this function reuses them. 
+ fn update_common_info(&mut self) { + self.char_to_str_pos.clear(); + self.str_to_char_pos.clear(); + self.char_type.clear(); + self.boundary_scores.clear(); + self.tag_scores.clear(); + let mut pos = 0; - char_to_str_pos.push(0); - for &c in chars { + self.char_to_str_pos.push(0); + for &c in &self.chars { pos += c.len_utf8(); - char_to_str_pos.push(pos); - char_type.push(CharacterType::get_type(c) as u8) + self.char_to_str_pos.push(pos); + self.char_type.push(CharacterType::get_type(c) as u8) } - let mut str_to_char_pos = vec![0; char_to_str_pos.last().unwrap_or(&0) + 1]; - for (i, &j) in char_to_str_pos.iter().enumerate() { - // j < str_to_char_pos.len() + + debug_assert!(pos == self.text.len()); + + self.str_to_char_pos.resize(self.text.len() + 1, 0); + for (i, &j) in self.char_to_str_pos.iter().enumerate() { + // j is always lower than pos + 1, so the following is safe. unsafe { - *str_to_char_pos.get_unchecked_mut(j) = i; + *self.str_to_char_pos.get_unchecked_mut(j) = i; } } - (char_to_str_pos, str_to_char_pos, char_type) } /// Creates a new [`Sentence`] from a given string. /// /// # Arguments /// - /// * `text` - A raw string without any annotation. + /// * `raw_text` - A raw string without any annotation. /// /// # Returns /// @@ -120,7 +482,7 @@ impl Sentence { /// /// # Errors /// - /// If the given `text` is empty, an error variant will be returned. + /// If the given `raw_text` is empty, an error variant will be returned. /// /// # Examples /// @@ -133,29 +495,62 @@ impl Sentence { /// let s = Sentence::from_raw(""); /// assert!(s.is_err()); /// ``` - pub fn from_raw(text: S) -> Result + pub fn from_raw(raw_text: S) -> Result where S: Into, { - let text = text.into(); + let raw_text = raw_text.into(); - if text.is_empty() { - return Err(anyhow!("`text` is empty")); - } - - let chars: Vec = text.chars().collect(); - let boundaries = vec![BoundaryType::Unknown; chars.len() - 1]; + let mut chars = Vec::with_capacity(0); + let mut boundaries = Vec::with_capacity(0); + let mut tags = Vec::with_capacity(0); + Self::parse_raw_text(&raw_text, &mut chars, &mut boundaries, &mut tags)?; - let (char_to_str_pos, str_to_char_pos, char_type) = Self::common_info(&chars); + Ok(Self::internal_new(raw_text, chars, boundaries, tags)) + } - Ok(Self { - text, - str_to_char_pos, - char_to_str_pos, - char_type, - boundaries, - boundary_scores: None, - }) + /// Updates the [`Sentence`] using a given string. + /// + /// # Arguments + /// + /// * `raw_text` - A raw string without any annotation. + /// + /// # Errors + /// + /// If the given `raw_text` is empty, an error variant will be returned. + /// When an error is occurred, the sentence will be replaced with a white space. + /// + /// # Examples + /// + /// ``` + /// use vaporetto::Sentence; + /// + /// let mut s = Sentence::from_raw("How are you?").unwrap(); + /// s.update_raw("I am file.").unwrap(); + /// assert_eq!("I am file.", s.to_raw_string()); + /// ``` + pub fn update_raw(&mut self, raw_text: S) -> Result<()> + where + S: Into, + { + let raw_text = raw_text.into(); + + match Self::parse_raw_text( + &raw_text, + &mut self.chars, + &mut self.boundaries, + &mut self.tags, + ) { + Ok(_) => { + self.text = raw_text; + self.update_common_info(); + Ok(()) + } + Err(e) => { + self.clear(); + Err(e) + } + } } /// Gets a string without any annotation. @@ -180,7 +575,10 @@ impl Sentence { /// /// # Arguments /// - /// * `tokenized_text` - A tokenized string containing whitespaces for word boundaries. 
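The `update_raw`, `update_tokenized`, and `update_partial_annotation` methods added here let a single `Sentence` and its internal buffers be reused across many inputs. A small sketch of that pattern, assuming a `predictor` built as sketched earlier:

```rust
use vaporetto::{Predictor, Sentence};

// Tokenizes many lines while reusing one Sentence and its buffers.
// `predictor` is assumed to have been constructed elsewhere.
fn tokenize_lines(predictor: &Predictor, lines: &[String]) {
    let mut s = Sentence::from_raw(" ").unwrap();
    for line in lines {
        s.update_raw(line.as_str()).unwrap();
        s = predictor.predict(s);
        println!("{}", s.to_tokenized_string().unwrap());
    }
}
```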
+ /// * `tokenized_text` - A tokenized text that is annotated by the following rules: + /// - A whitespace (`' '`) is inserted to each token boundary. + /// - If necessary, a POS tag following a slash (`'/'`) can be added to each token. + /// - Each character following a back slash (`'\\'`) is escaped. /// /// # Returns /// @@ -202,6 +600,9 @@ impl Sentence { /// let s = Sentence::from_tokenized("How are you?"); /// assert!(s.is_ok()); /// + /// let s = Sentence::from_tokenized("How/WRB are/VBP you?"); + /// assert!(s.is_ok()); + /// /// let s = Sentence::from_tokenized("How are you?"); /// assert!(s.is_err()); /// ``` @@ -211,56 +612,76 @@ impl Sentence { { let tokenized_text = tokenized_text.as_ref(); - if tokenized_text.is_empty() { - return Err(anyhow!("`tokenized_text` is empty")); - } + let mut text = String::with_capacity(0); + let mut chars = Vec::with_capacity(0); + let mut boundaries = Vec::with_capacity(0); + let mut tags = Vec::with_capacity(0); - let tokenized_chars: Vec = tokenized_text.chars().collect(); - let mut chars = Vec::with_capacity(tokenized_chars.len()); - let mut boundaries = Vec::with_capacity(tokenized_chars.len() - 1); + Self::parse_tokenized_text( + tokenized_text, + &mut text, + &mut chars, + &mut boundaries, + &mut tags, + )?; - let mut prev_boundary = false; - let mut escape = false; - for c in tokenized_chars { - match (escape, c) { - (false, '\\') => { - escape = true; - } - (false, ' ') => { - if chars.is_empty() { - return Err(anyhow!("`tokenized_text` starts with a whitespace")); - } else if prev_boundary { - return Err(anyhow!("`tokenized_text` contains consecutive whitespaces")); - } - prev_boundary = true; - } - (_, _) => { - if !chars.is_empty() { - boundaries.push(if prev_boundary { - BoundaryType::WordBoundary - } else { - BoundaryType::NotWordBoundary - }); - } - prev_boundary = false; - escape = false; - chars.push(c); - } - }; - } - if prev_boundary { - return Err(anyhow!("`tokenized_text` ends with a whitespace")); - } + Ok(Self::internal_new(text, chars, boundaries, tags)) + } - let (char_to_str_pos, str_to_char_pos, char_type) = Self::common_info(&chars); - Ok(Self { - text: chars.iter().collect(), - char_to_str_pos, - str_to_char_pos, - char_type, - boundaries, - boundary_scores: None, - }) + /// Updates the [`Sentence`] using tokenized string. + /// + /// # Arguments + /// + /// * `tokenized_text` - A tokenized text that is annotated by the following rules: + /// - A whitespace (`' '`) is inserted to each token boundary. + /// - If necessary, a POS tag following a slash (`'/'`) can be added to each token. + /// - Each character following a back slash (`'\\'`) is escaped. + /// + /// # Errors + /// + /// This function will return an error variant when: + /// + /// * `tokenized_text` is empty. + /// * `tokenized_text` starts/ends with a whitespace. + /// * `tokenized_text` contains consecutive whitespaces. + /// + /// When an error is occurred, the sentence will be replaced with a white space. 
+ /// + /// # Examples + /// + /// ``` + /// use vaporetto::Sentence; + /// + /// let mut s = Sentence::from_tokenized("How are you?").unwrap(); + /// + /// s.update_tokenized("I am fine").unwrap(); + /// assert_eq!("Iamfine", s.to_raw_string()); + /// + /// s.update_tokenized("How/WRB are/VBP you ?/.").unwrap(); + /// assert_eq!("Howareyou?", s.to_raw_string()); + /// ``` + pub fn update_tokenized(&mut self, tokenized_text: S) -> Result<()> + where + S: AsRef, + { + let tokenized_text = tokenized_text.as_ref(); + + match Self::parse_tokenized_text( + tokenized_text, + &mut self.text, + &mut self.chars, + &mut self.boundaries, + &mut self.tags, + ) { + Ok(_) => { + self.update_common_info(); + Ok(()) + } + Err(e) => { + self.clear(); + Err(e) + } + } } /// Generates a string with whitespaces for word boundaries. @@ -280,6 +701,9 @@ impl Sentence { /// /// let s = Sentence::from_tokenized("How are you?").unwrap(); /// assert_eq!("How are you?", s.to_tokenized_string().unwrap()); + /// + /// let s = Sentence::from_tokenized("How/WRB are/VBP you?").unwrap(); + /// assert_eq!("How/WRB are/VBP you?", s.to_tokenized_string().unwrap()); /// ``` pub fn to_tokenized_string(&self) -> Result { let chars: Vec = self.text.chars().collect(); @@ -289,14 +713,22 @@ impl Sentence { _ => (), } result.push(chars[0]); - for (&c, b) in chars[1..].iter().zip(&self.boundaries) { + for (i, (&c, b)) in chars[1..].iter().zip(&self.boundaries).enumerate() { match b { BoundaryType::WordBoundary => { + if !self.tags.is_empty() { + if let Some(tag) = self.tags.get(i).and_then(|x| x.as_ref()) { + result.push('/'); + result.push_str(tag); + } + } result.push(' '); } BoundaryType::NotWordBoundary => (), BoundaryType::Unknown => { - return Err(anyhow!("sentence contains an unknown boundary")); + return Err(VaporettoError::invalid_sentence( + "contains an unknown boundary", + )); } } match c { @@ -305,14 +737,18 @@ impl Sentence { } result.push(c); } + if let Some(tag) = self.tags.last().and_then(|x| x.as_ref()) { + result.push('/'); + result.push_str(tag); + } Ok(result) } - /// Generates a vector of words. + /// Generates a vector of tokens. /// /// # Returns /// - /// A newly allocated vector of words. + /// A newly allocated vector of tokens. 
/// /// # Errors /// @@ -321,35 +757,72 @@ impl Sentence { /// # Examples /// /// ``` - /// use vaporetto::Sentence; + /// use vaporetto::{Sentence, Token}; /// /// let s = Sentence::from_tokenized("How are you ?").unwrap(); /// assert_eq!(vec![ - /// "How", - /// "are", - /// "you", - /// "?", + /// Token { surface: "How", tag: None }, + /// Token { surface: "are", tag: None }, + /// Token { surface: "you", tag: None }, + /// Token { surface: "?", tag: None }, + /// ], s.to_tokenized_vec().unwrap()); + /// + /// let s = Sentence::from_tokenized("How/WRB are/VBP you/PRP ?/.").unwrap(); + /// assert_eq!(vec![ + /// Token { surface: "How", tag: Some("WRB") }, + /// Token { surface: "are", tag: Some("VBP") }, + /// Token { surface: "you", tag: Some("PRP") }, + /// Token { surface: "?", tag: Some(".") }, /// ], s.to_tokenized_vec().unwrap()); /// ``` - pub fn to_tokenized_vec(&self) -> Result> { + pub fn to_tokenized_vec(&self) -> Result> { let mut result = vec![]; let mut start = 0; - for (i, b) in self.boundaries.iter().enumerate() { - match b { - BoundaryType::WordBoundary => { - let end = unsafe { *self.char_to_str_pos.get_unchecked(i + 1) }; - let word = unsafe { self.text.get_unchecked(start..end) }; - result.push(word); - start = end; + if self.tags.is_empty() { + for (i, b) in self.boundaries.iter().enumerate() { + match b { + BoundaryType::WordBoundary => { + let end = unsafe { *self.char_to_str_pos.get_unchecked(i + 1) }; + let surface = unsafe { self.text.get_unchecked(start..end) }; + result.push(Token { surface, tag: None }); + start = end; + } + BoundaryType::NotWordBoundary => (), + BoundaryType::Unknown => { + return Err(VaporettoError::invalid_sentence( + "contains an unknown boundary", + )); + } } - BoundaryType::NotWordBoundary => (), - BoundaryType::Unknown => { - return Err(anyhow!("sentence contains an unknown boundary")); + } + let surface = unsafe { self.text.get_unchecked(start..) }; + result.push(Token { surface, tag: None }); + } else { + for (i, (b, tag)) in self.boundaries.iter().zip(&self.tags).enumerate() { + match b { + BoundaryType::WordBoundary => { + let end = unsafe { *self.char_to_str_pos.get_unchecked(i + 1) }; + let surface = unsafe { self.text.get_unchecked(start..end) }; + let tag = tag.as_ref().map(|x| x.as_str()); + result.push(Token { surface, tag }); + start = end; + } + BoundaryType::NotWordBoundary => (), + BoundaryType::Unknown => { + return Err(VaporettoError::invalid_sentence( + "contains an unknown boundary", + )); + } } } + let surface = unsafe { self.text.get_unchecked(start..) }; + let tag = self + .tags + .last() + .and_then(|x| x.as_ref()) + .map(|x| x.as_str()); + result.push(Token { surface, tag }); } - let word = unsafe { self.text.get_unchecked(start..) }; - result.push(word); Ok(result) } @@ -357,7 +830,12 @@ impl Sentence { /// /// # Arguments /// - /// * `labeled_text` - A string with partial annotations. + /// * `labeled_text` - A partially annotated text. Each character boundary is annotated by the following rules: + /// - If the boundary is a token boundary, a pipe symbol (`'|'`) is inserted. + /// - If the boundary is not a token boundary, a dash symobl (`'-'`) is inserted. + /// - If the boundary is not annotated, a whitespace (`' '`) is inserted. + /// + /// In addition, a POS tag following a slash (`'/'`) can be inserted to each token. 
/// /// # Returns /// @@ -379,6 +857,9 @@ impl Sentence { /// let s = Sentence::from_partial_annotation("g-o-o-d|i-d e-a"); /// assert!(s.is_ok()); /// + /// let s = Sentence::from_partial_annotation("I-t/PRP|'-s/VBZ|o-k-a-y/JJ|./."); + /// assert!(s.is_ok()); + /// /// let s = Sentence::from_partial_annotation("b-a-d/i-d-e-a"); /// assert!(s.is_err()); /// ``` @@ -388,69 +869,121 @@ impl Sentence { { let labeled_text = labeled_text.as_ref(); - if labeled_text.is_empty() { - return Err(anyhow!("`labeled_text` is empty")); - } - - let labeled_chars: Vec = labeled_text.chars().collect(); - if labeled_chars.len() & 0x01 == 0 { - return Err(anyhow!( - "invalid length for `labeled_text`: {}", - labeled_chars.len() - )); - } - let mut chars = Vec::with_capacity(labeled_chars.len() / 2 + 1); - let mut boundaries = Vec::with_capacity(labeled_chars.len() / 2); - - for c in labeled_chars.iter().skip(1).step_by(2) { - boundaries.push(match c { - ' ' => BoundaryType::Unknown, - '|' => BoundaryType::WordBoundary, - '-' => BoundaryType::NotWordBoundary, - _ => return Err(anyhow!("invalid boundary character: '{}'", c)), - }); - } - for c in labeled_chars.into_iter().step_by(2) { - chars.push(c); - } - - let (char_to_str_pos, str_to_char_pos, char_type) = Self::common_info(&chars); - Ok(Self { - text: chars.iter().collect(), - char_to_str_pos, - str_to_char_pos, - char_type, - boundaries, - boundary_scores: None, - }) + let mut text = String::with_capacity(0); + let mut chars = Vec::with_capacity(0); + let mut boundaries = Vec::with_capacity(0); + let mut tags = Vec::with_capacity(0); + Self::parse_partial_annotation( + labeled_text, + &mut text, + &mut chars, + &mut boundaries, + &mut tags, + )?; + + Ok(Self::internal_new(text, chars, boundaries, tags)) } - /// Generates a string with partial annotations. + /// Updates the [`Sentence`] using a string with partial annotations. /// - /// # Returns + /// # Arguments /// - /// A newly allocated string with partial annotations. + /// * `labeled_text` - A partially annotated text. Each character boundary is annotated by the following rules: + /// - If the boundary is a token boundary, a pipe symbol (`'|'`) is inserted. + /// - If the boundary is not a token boundary, a dash symobl (`'-'`) is inserted. + /// - If the boundary is not annotated, a whitespace (`' '`) is inserted. + /// + /// In addition, a POS tag following a slash (`'/'`) can be inserted to each token. + /// + /// # Errors + /// + /// This function will return an error variant when: + /// + /// * `labeled_text` is empty. + /// * The length of `lsbeled_text` is even numbers. + /// * `labeled_text` contains invalid boundary characters. + /// + /// When an error is occurred, the sentence will be replaced with a white space. 
/// /// # Examples /// /// ``` /// use vaporetto::Sentence; /// - /// let s = Sentence::from_tokenized("How are you ?").unwrap(); - /// assert_eq!("H-o-w|a-r-e|y-o-u|?", &s.to_partial_annotation_string()); + /// let mut s = Sentence::from_raw("g-o-o-d|i-d e-a").unwrap(); + /// s.update_partial_annotation("h-e-l-l-o").unwrap(); + /// assert_eq!("hello", s.to_raw_string()); + /// + /// s.update_partial_annotation("I-t/PRP|'-s/VBZ|o-k-a-y/JJ|./.").unwrap(); + /// assert_eq!("It'sokay.", s.to_raw_string()); /// ``` - pub fn to_partial_annotation_string(&self) -> String { - let chars: Vec = self.text.chars().collect(); - let mut result = String::with_capacity(self.text.len() + chars.len() - 1); - result.push(chars[0]); - for (&c, b) in chars[1..].iter().zip(&self.boundaries) { - result.push(match b { - BoundaryType::WordBoundary => '|', - BoundaryType::NotWordBoundary => '-', - BoundaryType::Unknown => ' ', - }); - result.push(c); - } + pub fn update_partial_annotation(&mut self, labeled_text: S) -> Result<()> + where + S: AsRef, + { + let labeled_text = labeled_text.as_ref(); + + match Self::parse_partial_annotation( + labeled_text, + &mut self.text, + &mut self.chars, + &mut self.boundaries, + &mut self.tags, + ) { + Ok(_) => { + self.update_common_info(); + Ok(()) + } + Err(e) => { + self.clear(); + Err(e) + } + } + } + + /// Generates a string with partial annotations. + /// + /// # Returns + /// + /// A newly allocated string with partial annotations. + /// + /// # Examples + /// + /// ``` + /// use vaporetto::Sentence; + /// + /// let s = Sentence::from_tokenized("How are you ?").unwrap(); + /// assert_eq!("H-o-w|a-r-e|y-o-u|?", &s.to_partial_annotation_string()); + /// + /// let s = Sentence::from_tokenized("How/WRB are you/PRP ?").unwrap(); + /// assert_eq!("H-o-w/WRB|a-r-e|y-o-u/PRP|?", &s.to_partial_annotation_string()); + /// ``` + pub fn to_partial_annotation_string(&self) -> String { + let chars: Vec = self.text.chars().collect(); + let mut result = String::with_capacity(self.text.len() + chars.len() - 1); + result.push(chars[0]); + for (i, (&c, b)) in chars[1..].iter().zip(&self.boundaries).enumerate() { + match b { + BoundaryType::WordBoundary => { + if let Some(tag) = self.tags.get(i).and_then(|x| x.as_ref()) { + result.push('/'); + result.push_str(tag); + } + result.push('|'); + } + BoundaryType::NotWordBoundary => { + result.push('-'); + } + BoundaryType::Unknown => { + result.push(' '); + } + } + result.push(c); + } + if let Some(tag) = self.tags.last().and_then(|x| x.as_ref()) { + result.push('/'); + result.push_str(tag); + } result } @@ -485,6 +1018,94 @@ impl Sentence { &mut self.boundaries } + /// Gets a reference to the part-of-speech information. + /// + /// Each tag is placed at the last of the corresponding token. For example, when the first token + /// containing three characters has a tag, that tag will be placed at the third element of the + /// returned slice. + /// + /// # Returns + /// + /// A reference to the POS information. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// + /// use vaporetto::{BoundaryType, Sentence}; + /// + /// let s = Sentence::from_tokenized("I/PRP am a/DT cat/NN ./.").unwrap(); + /// assert_eq!(&[ + /// Some(Arc::new("PRP".to_string())), // 'I' + /// None, // 'a' + /// None, // 'm' + /// Some(Arc::new("DT".to_string())), // 'a' + /// None, // 'c' + /// None, // 'a' + /// Some(Arc::new("NN".to_string())), // 't' + /// Some(Arc::new(".".to_string())), // '.' 
+ /// ], s.tags()); + /// ``` + pub fn tags(&self) -> &[Option>] { + &self.tags + } + + /// Gets a mutable reference to the part-of-speech information. + /// + /// # Returns + /// + /// A mutable reference to the part-of-speech information. + pub fn tags_mut(&mut self) -> &mut [Option>] { + &mut self.tags + } + + /// Gets a reference to the characters. + /// + /// # Returns + /// + /// A reference to the characters. + /// + /// # Examples + /// + /// ``` + /// use vaporetto::Sentence; + /// + /// let s = Sentence::from_raw("A1あエ漢?").unwrap(); + /// assert_eq!(&['A', '1', 'あ', 'エ', '漢', '?'], s.chars()); + /// ``` + pub fn chars(&self) -> &[char] { + &self.chars + } + + /// Gets immutable references to the characters and character types, and a mutable reference to + /// boundaries. + /// + /// # Returns + /// + /// A tuple of references. + /// + /// # Examples + /// + /// ``` + /// use vaporetto::{BoundaryType, Sentence}; + /// + /// let mut s = Sentence::from_partial_annotation("A-1|あ エ-漢|?").unwrap(); + /// let (chars, char_types, boundaries) = s.chars_and_boundaries_mut(); + /// assert_eq!(&['A', '1', 'あ', 'エ', '漢', '?'], chars); + /// assert_eq!(&[b'R', b'D', b'H', b'T', b'K', b'O'], char_types); + /// assert_eq!(&[ + /// BoundaryType::NotWordBoundary, + /// BoundaryType::WordBoundary, + /// BoundaryType::Unknown, + /// BoundaryType::NotWordBoundary, + /// BoundaryType::WordBoundary, + /// ], boundaries); + /// ``` + pub fn chars_and_boundaries_mut(&mut self) -> (&[char], &[u8], &mut [BoundaryType]) { + (&self.chars, &self.char_type, &mut self.boundaries) + } + /// Gets a reference to the character type information. /// /// # Returns @@ -497,7 +1118,7 @@ impl Sentence { /// use vaporetto::Sentence; /// /// let s = Sentence::from_raw("A1あエ漢?").unwrap(); - /// assert_eq!(&[b'R', b'D', b'H', b'T', b'K', b'O',], s.char_types()); + /// assert_eq!(&[b'R', b'D', b'H', b'T', b'K', b'O'], s.char_types()); /// ``` pub fn char_types(&self) -> &[u8] { &self.char_type @@ -508,8 +1129,8 @@ impl Sentence { /// # Returns /// /// If the predictor inserted, the boundary score information is returned. Otherwise, None. - pub fn boundary_scores(&self) -> Option<&[f64]> { - self.boundary_scores.as_deref() + pub fn boundary_scores(&self) -> &[i32] { + &self.boundary_scores } /// Gets a character position in the code point unit. 
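`boundary_scores()` now returns the quantized integer scores directly instead of `Option<&[f64]>`. A brief sketch of reading them after `predict_with_score`, again assuming a `predictor` built as above:

```rust
// Print each boundary decision together with its quantized score.
// `predictor` is assumed to exist; the scores are the internal i32 weights.
let s = predictor.predict_with_score(Sentence::from_raw("我らは全世界の国民").unwrap());
for (b, score) in s.boundaries().iter().zip(s.boundary_scores()) {
    println!("{:?}\t{}", b, score);
}
```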
@@ -527,7 +1148,7 @@ impl Sentence { } else { match self.str_to_char_pos.get(index) { Some(index) if *index != 0 => Ok(*index), - _ => Err(anyhow!("invalid index")), + _ => Err(VaporettoError::invalid_argument("index", "invalid index")), } } } @@ -538,25 +1159,79 @@ impl Sentence { let end = self.char_to_str_pos[end]; &self.text.as_str()[begin..end] } - - #[cfg(feature = "train")] - pub(crate) fn type_substring(&self, start: usize, end: usize) -> &[u8] { - &self.char_type[start..end] - } } #[cfg(test)] mod tests { use super::*; use BoundaryType::*; - use CharacterType::*; #[test] fn test_sentence_from_raw_empty() { let s = Sentence::from_raw(""); - assert!(s.is_err()); - assert_eq!("`text` is empty", &s.err().unwrap().to_string()); + assert_eq!( + "InvalidArgumentError: raw_text: must contain at least one character", + &s.err().unwrap().to_string() + ); + } + + #[test] + fn test_sentence_update_raw_empty() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_raw(""); + + assert_eq!( + "InvalidArgumentError: raw_text: must contain at least one character", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: b"O".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); + } + + #[test] + fn test_sentence_from_raw_null() { + let s = Sentence::from_raw("A1あ\0ア亜"); + + assert_eq!( + "InvalidArgumentError: raw_text: must not contain NULL", + &s.err().unwrap().to_string() + ); + } + + #[test] + fn test_sentence_update_raw_null() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_raw("A1あ\0ア亜"); + + assert_eq!( + "InvalidArgumentError: raw_text: must not contain NULL", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: b"O".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); } #[test] @@ -565,21 +1240,47 @@ mod tests { let expected = Sentence { text: "あ".to_string(), + chars: vec!['あ'], str_to_char_pos: vec![0, 0, 0, 1], char_to_str_pos: vec![0, 3], - char_type: ct2u8vec![Hiragana], + char_type: b"H".to_vec(), boundaries: vec![], - boundary_scores: None, + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], }; assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_raw_one() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_raw("あ").unwrap(); + + let expected = Sentence { + text: "あ".to_string(), + chars: vec!['あ'], + str_to_char_pos: vec![0, 0, 0, 1], + char_to_str_pos: vec![0, 3], + char_type: b"H".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_raw() { let s = Sentence::from_raw("Rustで良いプログラミング体験を!"); let expected = Sentence { text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], str_to_char_pos: vec![ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, @@ -587,16 +1288,42 @@ mod tests { 
char_to_str_pos: vec![ 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, ], - char_type: ct2u8vec![ - Roman, Roman, Roman, Roman, Hiragana, Kanji, Hiragana, Katakana, Katakana, - Katakana, Katakana, Katakana, Katakana, Katakana, Kanji, Kanji, Hiragana, Other, - ], + char_type: b"RRRRHKHTTTTTTTKKHO".to_vec(), boundaries: vec![Unknown; 17], - boundary_scores: None, + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 18], }; assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_raw() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_raw("Rustで良いプログラミング体験を!").unwrap(); + + let expected = Sentence { + text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], + str_to_char_pos: vec![ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, + ], + char_to_str_pos: vec![ + 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, + ], + char_type: b"RRRRHKHTTTTTTTKKHO".to_vec(), + boundaries: vec![Unknown; 17], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 18], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_to_raw() { let s = Sentence::from_raw("Rustで良いプログラミング体験を!"); @@ -611,64 +1338,219 @@ mod tests { fn test_sentence_from_tokenized_empty() { let s = Sentence::from_tokenized(""); - assert!(s.is_err()); - assert_eq!("`tokenized_text` is empty", &s.err().unwrap().to_string()); + assert_eq!( + "InvalidArgumentError: tokenized_text: must contain at least one character", + &s.err().unwrap().to_string() + ); + } + + #[test] + fn test_sentence_update_tokenized_empty() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_tokenized(""); + + assert_eq!( + "InvalidArgumentError: tokenized_text: must contain at least one character", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: b"O".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); + } + + #[test] + fn test_sentence_from_tokenized_null() { + let s = Sentence::from_tokenized("A1あ\0ア亜"); + + assert_eq!( + "InvalidArgumentError: tokenized_text: must not contain NULL", + &s.err().unwrap().to_string() + ); + } + + #[test] + fn test_sentence_update_tokenized_null() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_tokenized("A1あ\0ア亜"); + + assert_eq!( + "InvalidArgumentError: tokenized_text: must not contain NULL", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: b"O".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); } #[test] fn test_sentence_from_tokenized_start_with_space() { let s = Sentence::from_tokenized(" Rust で 良い プログラミング 体験 を !"); - assert!(s.is_err()); assert_eq!( - "`tokenized_text` starts with a whitespace", + "InvalidArgumentError: tokenized_text: must not start with a whitespace", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_tokenized_start_with_space() { + let mut 
s = Sentence::from_raw("12345").unwrap(); + let result = s.update_tokenized(" Rust で 良い プログラミング 体験 を !"); + + assert_eq!( + "InvalidArgumentError: tokenized_text: must not start with a whitespace", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: b"O".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized_end_with_space() { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を ! "); - assert!(s.is_err()); assert_eq!( - "`tokenized_text` ends with a whitespace", + "InvalidArgumentError: tokenized_text: must not end with a whitespace", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_tokenized_end_with_space() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_tokenized("Rust で 良い プログラミング 体験 を ! "); + + assert_eq!( + "InvalidArgumentError: tokenized_text: must not end with a whitespace", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: b"O".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized_two_spaces() { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !"); - assert!(s.is_err()); assert_eq!( - "`tokenized_text` contains consecutive whitespaces", + "InvalidArgumentError: tokenized_text: must not contain consecutive whitespaces", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_tokenized_two_spaces() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_tokenized("Rust で 良い プログラミング 体験 を !"); + + assert_eq!( + "InvalidArgumentError: tokenized_text: must not contain consecutive whitespaces", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: b"O".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized_one() { let s = Sentence::from_tokenized("あ"); let expected = Sentence { text: "あ".to_string(), + chars: vec!['あ'], str_to_char_pos: vec![0, 0, 0, 1], char_to_str_pos: vec![0, 3], - char_type: ct2u8vec![Hiragana], + char_type: b"H".to_vec(), boundaries: vec![], - boundary_scores: None, + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], }; assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_tokenized_one() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("あ").unwrap(); + + let expected = Sentence { + text: "あ".to_string(), + chars: vec!['あ'], + str_to_char_pos: vec![0, 0, 0, 1], + char_to_str_pos: vec![0, 3], + char_type: b"H".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized() { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !"); let expected = Sentence { text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', 
'良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], str_to_char_pos: vec![ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, @@ -676,10 +1558,52 @@ mod tests { char_to_str_pos: vec![ 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, ], - char_type: ct2u8vec![ - Roman, Roman, Roman, Roman, Hiragana, Kanji, Hiragana, Katakana, Katakana, - Katakana, Katakana, Katakana, Katakana, Katakana, Kanji, Kanji, Hiragana, Other, + char_type: b"RRRRHKHTTTTTTTKKHO".to_vec(), + boundaries: vec![ + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + ], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 18], + }; + assert_eq!(expected, s.unwrap()); + } + + #[test] + fn test_sentence_from_tokenized_with_tags() { + let s = + Sentence::from_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号"); + + let expected = Sentence { + text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], + str_to_char_pos: vec![ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, + ], + char_to_str_pos: vec![ + 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, ], + char_type: b"RRRRHKHTTTTTTTKKHO".to_vec(), boundaries: vec![ NotWordBoundary, NotWordBoundary, @@ -699,17 +1623,153 @@ mod tests { WordBoundary, WordBoundary, ], - boundary_scores: None, + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![ + None, + None, + None, + Some(Arc::new("名詞".to_string())), + None, + None, + Some(Arc::new("形容詞".to_string())), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + Some(Arc::new("補助記号".to_string())), + ], }; assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_tokenized() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("Rust で 良い プログラミング 体験 を !") + .unwrap(); + + let expected = Sentence { + text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], + str_to_char_pos: vec![ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, + ], + char_to_str_pos: vec![ + 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, + ], + char_type: b"RRRRHKHTTTTTTTKKHO".to_vec(), + boundaries: vec![ + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + ], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 18], + }; + assert_eq!(expected, s); + } + + #[test] + fn test_sentence_update_tokenized_with_tags() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号") + .unwrap(); + + let expected 
= Sentence { + text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], + str_to_char_pos: vec![ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, + ], + char_to_str_pos: vec![ + 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, + ], + char_type: b"RRRRHKHTTTTTTTKKHO".to_vec(), + boundaries: vec![ + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + ], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![ + None, + None, + None, + Some(Arc::new("名詞".to_string())), + None, + None, + Some(Arc::new("形容詞".to_string())), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + Some(Arc::new("補助記号".to_string())), + ], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized_with_escape_whitespace() { - let s = Sentence::from_tokenized("火星 猫 の 生態 ( M \\ et\\ al. )"); + let s = Sentence::from_tokenized("火星 猫 の 生態 ( M \\ et\\ al. )").unwrap(); let expected = Sentence { text: "火星猫の生態(M et al.)".to_string(), + chars: vec![ + '火', '星', '猫', 'の', '生', '態', '(', 'M', ' ', 'e', 't', ' ', 'a', 'l', '.', + ')', + ], str_to_char_pos: vec![ 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -717,10 +1777,51 @@ mod tests { char_to_str_pos: vec![ 0, 3, 6, 9, 12, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, ], - char_type: ct2u8vec![ - Kanji, Kanji, Kanji, Hiragana, Kanji, Kanji, Other, Roman, Other, Roman, Roman, - Other, Roman, Roman, Other, Other, + char_type: b"KKKHKKORORRORROO".to_vec(), + boundaries: vec![ + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + ], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 16], + }; + assert_eq!(expected, s); + } + + #[test] + fn test_sentence_update_tokenized_escape_whitespace() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("火星 猫 の 生態 ( M \\ et\\ al. 
)") + .unwrap(); + + let expected = Sentence { + text: "火星猫の生態(M et al.)".to_string(), + chars: vec![ + '火', '星', '猫', 'の', '生', '態', '(', 'M', ' ', 'e', 't', ' ', 'a', 'l', '.', + ')', ], + str_to_char_pos: vec![ + 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, + ], + char_to_str_pos: vec![ + 0, 3, 6, 9, 12, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + ], + char_type: b"KKKHKKORORRORROO".to_vec(), boundaries: vec![ NotWordBoundary, WordBoundary, @@ -738,9 +1839,11 @@ mod tests { NotWordBoundary, WordBoundary, ], - boundary_scores: None, + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 16], }; - assert_eq!(expected, s.unwrap()); + assert_eq!(expected, s); } #[test] @@ -749,13 +1852,42 @@ mod tests { let expected = Sentence { text: "改行に\\nを用いる".to_string(), + chars: vec!['改', '行', 'に', '\\', 'n', 'を', '用', 'い', 'る'], str_to_char_pos: vec![ 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, ], char_to_str_pos: vec![0, 3, 6, 9, 10, 11, 14, 17, 20, 23], - char_type: ct2u8vec![ - Kanji, Kanji, Hiragana, Other, Roman, Hiragana, Kanji, Hiragana, Hiragana, + char_type: b"KKHORHKHH".to_vec(), + boundaries: vec![ + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, ], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 9], + }; + assert_eq!(expected, s.unwrap()); + } + + #[test] + fn test_sentence_update_tokenized_with_escape_backslash() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("改行 に \\\\n を 用い る").unwrap(); + + let expected = Sentence { + text: "改行に\\nを用いる".to_string(), + chars: vec!['改', '行', 'に', '\\', 'n', 'を', '用', 'い', 'る'], + str_to_char_pos: vec![ + 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, + ], + char_to_str_pos: vec![0, 3, 6, 9, 10, 11, 14, 17, 20, 23], + char_type: b"KKHORHKHH".to_vec(), boundaries: vec![ NotWordBoundary, WordBoundary, @@ -766,19 +1898,77 @@ mod tests { NotWordBoundary, WordBoundary, ], - boundary_scores: None, + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 9], + }; + assert_eq!(expected, s); + } + + #[test] + fn test_sentence_from_tokenized_escape_slash() { + let s = Sentence::from_tokenized("品詞 に \\/ を 用い る"); + + let expected = Sentence { + text: "品詞に/を用いる".to_string(), + chars: vec!['品', '詞', 'に', '/', 'を', '用', 'い', 'る'], + str_to_char_pos: vec![ + 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, + ], + char_to_str_pos: vec![0, 3, 6, 9, 10, 13, 16, 19, 22], + char_type: b"KKHOHKHH".to_vec(), + boundaries: vec![ + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + ], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 8], }; assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_tokenized_escape_slash() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("品詞 に \\/ を 用い る").unwrap(); + + let expected = Sentence { + text: "品詞に/を用いる".to_string(), + chars: vec!['品', '詞', 'に', '/', 'を', '用', 'い', 'る'], + str_to_char_pos: vec![ + 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, + ], + char_to_str_pos: vec![0, 3, 6, 9, 10, 13, 16, 19, 22], + char_type: b"KKHOHKHH".to_vec(), + boundaries: vec![ + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + 
NotWordBoundary, + WordBoundary, + ], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 8], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_to_tokenized_string_unknown() { let s = Sentence::from_partial_annotation("火-星 猫|の|生-態"); let result = s.unwrap().to_tokenized_string(); - assert!(result.is_err()); assert_eq!( - "sentence contains an unknown boundary", + "InvalidSentenceError: contains an unknown boundary", result.err().unwrap().to_string() ); } @@ -793,6 +1983,17 @@ mod tests { ); } + #[test] + fn test_sentence_to_tokenized_string_with_tags() { + let s = + Sentence::from_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号"); + + assert_eq!( + "Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号", + s.unwrap().to_tokenized_string().unwrap() + ); + } + #[test] fn test_sentence_to_tokenized_string_escape() { let s = Sentence::from_partial_annotation("火-星-猫|の| |生-態|\\-n"); @@ -808,9 +2009,8 @@ mod tests { let s = Sentence::from_partial_annotation("火-星 猫|の|生-態").unwrap(); let result = s.to_tokenized_vec(); - assert!(result.is_err()); assert_eq!( - "sentence contains an unknown boundary", + "InvalidSentenceError: contains an unknown boundary", result.err().unwrap().to_string() ); } @@ -820,7 +2020,77 @@ mod tests { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !").unwrap(); assert_eq!( - vec!["Rust", "で", "良い", "プログラミング", "体験", "を", "!"], + vec![ + Token { + surface: "Rust", + tag: None + }, + Token { + surface: "で", + tag: None + }, + Token { + surface: "良い", + tag: None + }, + Token { + surface: "プログラミング", + tag: None + }, + Token { + surface: "体験", + tag: None + }, + Token { + surface: "を", + tag: None + }, + Token { + surface: "!", + tag: None + }, + ], + s.to_tokenized_vec().unwrap() + ); + } + + #[test] + fn test_sentence_to_tokenized_vec_with_tags() { + let s = + Sentence::from_tokenized("Rust/名詞 で 良い/形容詞 プログラミング 体験 を !/補助記号") + .unwrap(); + + assert_eq!( + vec![ + Token { + surface: "Rust", + tag: Some("名詞"), + }, + Token { + surface: "で", + tag: None, + }, + Token { + surface: "良い", + tag: Some("形容詞"), + }, + Token { + surface: "プログラミング", + tag: None, + }, + Token { + surface: "体験", + tag: None, + }, + Token { + surface: "を", + tag: None, + }, + Token { + surface: "!", + tag: Some("補助記号"), + }, + ], s.to_tokenized_vec().unwrap() ); } @@ -829,41 +2099,109 @@ mod tests { fn test_sentence_from_partial_annotation_empty() { let s = Sentence::from_partial_annotation(""); - assert!(s.is_err()); - assert_eq!("`labeled_text` is empty", &s.err().unwrap().to_string()); + assert_eq!( + "InvalidArgumentError: labeled_text: must contain at least one character", + &s.err().unwrap().to_string() + ); } #[test] - fn test_sentence_from_partial_annotation_invalid_length() { - let s = Sentence::from_partial_annotation("火-星 猫|の|生-態 "); + fn test_sentence_update_partial_annotation_empty() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_partial_annotation(""); + + assert_eq!( + "InvalidArgumentError: labeled_text: must contain at least one character", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: b"O".to_vec(), + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); + } + + #[test] + fn test_sentence_from_partial_annotation_null() { + let s = Sentence::from_partial_annotation("A-1-あ-\0-ア-亜"); - 
assert!(s.is_err()); assert_eq!( - "invalid length for `labeled_text`: 12", + "InvalidArgumentError: labeled_text: must not contain NULL", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_partial_annotation_null() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_partial_annotation("A-1-あ-\0-ア-亜"); + + assert_eq!( + "InvalidArgumentError: labeled_text: must not contain NULL", + &result.err().unwrap().to_string() + ); + } + + #[test] + fn test_sentence_from_partial_annotation_invalid_length() { + let result = Sentence::from_partial_annotation("火-星 猫|の|生-態 "); + + assert_eq!( + "InvalidArgumentError: labeled_text: invalid annotation", + &result.err().unwrap().to_string() + ); + } + + #[test] + fn test_sentence_update_partial_annotation_invalid_length() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_partial_annotation("火-星 猫|の|生-態 "); + + assert_eq!( + "InvalidArgumentError: labeled_text: invalid annotation", + &result.err().unwrap().to_string() + ); + } + #[test] fn test_sentence_from_partial_annotation_invalid_boundary_character() { let s = Sentence::from_partial_annotation("火-星?猫|の|生-態"); - assert!(s.is_err()); assert_eq!( - "invalid boundary character: '?'", + "InvalidArgumentError: labeled_text: contains an invalid boundary character: '?'", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_partial_annotation_invalid_boundary_character() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_partial_annotation("火-星?猫|の|生-態"); + + assert_eq!( + "InvalidArgumentError: labeled_text: contains an invalid boundary character: '?'", + &result.err().unwrap().to_string() + ); + } + #[test] fn test_sentence_from_partial_annotation_one() { let s = Sentence::from_partial_annotation("火-星 猫|の|生-態"); let expected = Sentence { text: "火星猫の生態".to_string(), + chars: vec!['火', '星', '猫', 'の', '生', '態'], str_to_char_pos: vec![0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6], char_to_str_pos: vec![0, 3, 6, 9, 12, 15, 18], - char_type: ct2u8vec![Kanji, Kanji, Kanji, Hiragana, Kanji, Kanji], + char_type: b"KKKHKK".to_vec(), boundaries: vec![ NotWordBoundary, Unknown, @@ -871,11 +2209,38 @@ mod tests { WordBoundary, NotWordBoundary, ], - boundary_scores: None, + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 6], }; assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_partial_annotation_one() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_partial_annotation("火-星 猫|の|生-態").unwrap(); + + let expected = Sentence { + text: "火星猫の生態".to_string(), + chars: vec!['火', '星', '猫', 'の', '生', '態'], + str_to_char_pos: vec![0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6], + char_to_str_pos: vec![0, 3, 6, 9, 12, 15, 18], + char_type: b"KKKHKK".to_vec(), + boundaries: vec![ + NotWordBoundary, + Unknown, + WordBoundary, + WordBoundary, + NotWordBoundary, + ], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None; 6], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_to_partial_annotation_string() { let s = Sentence::from_partial_annotation("火-星 猫|の|生-態"); @@ -885,4 +2250,14 @@ mod tests { s.unwrap().to_partial_annotation_string() ); } + + #[test] + fn test_sentence_to_partial_annotation_string_with_tags() { + let s = Sentence::from_partial_annotation("火-星 猫|の/助詞|生-態/名詞"); + + assert_eq!( + "火-星 猫|の/助詞|生-態/名詞", + s.unwrap().to_partial_annotation_string() + ); + } } diff --git 
a/vaporetto/src/tag_model.rs b/vaporetto/src/tag_model.rs new file mode 100644 index 00000000..205ccfb7 --- /dev/null +++ b/vaporetto/src/tag_model.rs @@ -0,0 +1,85 @@ +use std::io::{Read, Write}; + +use crate::errors::Result; +use crate::ngram_model::NgramModel; +use crate::utils; + +pub struct TagClassInfo { + pub(crate) name: String, + pub(crate) bias: i32, +} + +impl TagClassInfo { + pub fn serialize(&self, mut wtr: W) -> Result<()> + where + W: Write, + { + utils::write_u32(&mut wtr, self.name.len().try_into().unwrap())?; + wtr.write_all(self.name.as_bytes())?; + utils::write_i32(&mut wtr, self.bias)?; + Ok(()) + } + + pub fn deserialize(mut rdr: R) -> Result + where + R: Read, + { + let name_size = utils::read_u32(&mut rdr)?; + let mut name_bytes = vec![0; name_size.try_into().unwrap()]; + rdr.read_exact(&mut name_bytes)?; + let name = String::from_utf8(name_bytes)?; + Ok(Self { + name, + bias: utils::read_i32(&mut rdr)?, + }) + } +} + +// Left and right weight arrays of the TagModel are ordered as follows: +// +// tok1 tok2 tok3 ... +// +// tag1 1 5 9 +// tag2 2 6 . +// tag3 3 7 . +// ... 4 8 . +#[derive(Default)] +pub struct TagModel { + pub(crate) class_info: Vec, + pub(crate) left_char_model: NgramModel, + pub(crate) right_char_model: NgramModel, + pub(crate) self_char_model: NgramModel, +} + +impl TagModel { + pub fn serialize(&self, mut wtr: W) -> Result<()> + where + W: Write, + { + utils::write_u32(&mut wtr, self.class_info.len().try_into().unwrap())?; + for cls in &self.class_info { + cls.serialize(&mut wtr)?; + } + self.left_char_model.serialize(&mut wtr)?; + self.right_char_model.serialize(&mut wtr)?; + self.self_char_model.serialize(&mut wtr)?; + Ok(()) + } + + pub fn deserialize(mut rdr: R) -> Result + where + R: Read, + { + let n_class = utils::read_u32(&mut rdr)?; + let mut class_info = vec![]; + for _ in 0..n_class { + class_info.push(TagClassInfo::deserialize(&mut rdr)?); + } + Ok(Self { + class_info, + left_char_model: NgramModel::::deserialize(&mut rdr)?, + right_char_model: NgramModel::::deserialize(&mut rdr)?, + self_char_model: NgramModel::::deserialize(&mut rdr)?, + }) + } +} diff --git a/vaporetto/src/tag_trainer.rs b/vaporetto/src/tag_trainer.rs new file mode 100644 index 00000000..6717ce2e --- /dev/null +++ b/vaporetto/src/tag_trainer.rs @@ -0,0 +1,194 @@ +use std::collections::BTreeMap; + +use liblinear::LibLinearModel; + +use crate::errors::{Result, VaporettoError}; +use crate::feature::{StringNgramFeature, TagExampleGenerator, TagFeature}; +use crate::ngram_model::{NgramData, NgramModel}; +use crate::sentence::Sentence; +use crate::tag_model::{TagClassInfo, TagModel}; +use crate::trainer::{Indexer, SolverType, QUANTIZE_BIT_DEPTH}; + +pub struct TagTrainer<'a> { + example_generator: TagExampleGenerator, + char_window_size: usize, + feature_ids: Indexer>, + tag_ids: Indexer, + xs: Vec>, + ys: Vec, +} + +impl<'a> TagTrainer<'a> { + pub fn new(char_ngram_size: usize, char_window_size: usize) -> Self { + Self { + example_generator: TagExampleGenerator::new(char_ngram_size, char_window_size), + char_window_size, + feature_ids: Indexer::new(), + tag_ids: Indexer::new(), + xs: vec![], + ys: vec![], + } + } + + pub fn push_sentence(&mut self, s: &'a Sentence) -> Result<()> { + let examples = self.example_generator.generate(s)?; + for example in examples { + let mut feature_ids = BTreeMap::new(); + for f in &example.features { + let fid = self.feature_ids.get_id(f); + *feature_ids + .entry((fid + 1).try_into().unwrap()) + .or_insert(0.0) += 1.0; + } + 
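+            // `feature_ids` now maps 1-based LIBLINEAR feature indices to occurrence
+            // counts; collect it into the sparse example vector for this tag example.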
self.xs.push(feature_ids.into_iter().collect()); + self.ys + .push(self.tag_ids.get_id(example.tag.as_str()) as f64); + } + Ok(()) + } + + pub fn n_features(&self) -> usize { + self.feature_ids.len() + } + + pub fn train(self, epsilon: f64, cost: f64, solver: SolverType) -> Result { + if self.xs.is_empty() { + // Returns an empty model if there is no training data. + return Ok(TagModel::default()); + } + + let mut builder = liblinear::Builder::new(); + let training_input = liblinear::util::TrainingInput::from_sparse_features(self.ys, self.xs) + .map_err(|e| VaporettoError::invalid_model(format!("liblinear error: {:?}", e)))?; + builder.problem().input_data(training_input).bias(1.0); + builder + .parameters() + .solver_type(solver.into()) + .stopping_criterion(epsilon) + .constraints_violation_cost(cost); + let model = builder + .build_model() + .map_err(|e| VaporettoError::invalid_model(e.to_string()))?; + + // Uses BTreeMap to increase compression ratio. + let mut left_char_weights: BTreeMap<_, Vec<_>> = BTreeMap::new(); + let mut right_char_weights: BTreeMap<_, Vec<_>> = BTreeMap::new(); + let mut self_char_weights: BTreeMap<_, Vec<_>> = BTreeMap::new(); + + let mut weight_max = 0.; + for i in 0..self.tag_ids.len() as i32 { + let weight = model.label_bias(i).abs(); + if weight > weight_max { + weight_max = weight; + } + for fid in 0..model.num_features() { + let weight = model.feature_coefficient(fid as i32, i).abs(); + if weight > weight_max { + weight_max = weight; + } + } + } + let quantize_multiplier = weight_max / ((1 << (QUANTIZE_BIT_DEPTH - 1)) - 1) as f64; + + let mut class_info = vec![]; + + for i in 0..self.tag_ids.len() { + class_info.push(TagClassInfo { + name: self.tag_ids.keys()[model.labels()[i] as usize].clone(), + bias: (model.label_bias(i as i32) / quantize_multiplier) as i32, + }); + + for (fid, feature) in self.feature_ids.keys().iter().enumerate() { + let raw_weight = model.feature_coefficient(fid as i32 + 1, i as i32); + let weight = (raw_weight / quantize_multiplier) as i32; + + if weight == 0 { + continue; + } + + match feature { + TagFeature::LeftCharacterNgram(StringNgramFeature { + rel_position, + ngram, + }) => { + let pos = -rel_position - 1; + let idx = i + pos as usize * self.tag_ids.len(); + if let Some(weights) = left_char_weights.get_mut(*ngram) { + weights[idx] = weight; + } else { + let mut weights = vec![0; self.char_window_size * self.tag_ids.len()]; + weights[idx] = weight; + left_char_weights.insert(ngram.to_string(), weights); + } + } + TagFeature::LeftCharacterNgramBos(StringNgramFeature { + rel_position, + ngram, + }) => { + let pos = -rel_position - 1; + let idx = i + pos as usize * self.tag_ids.len(); + let ngram = "\0".to_string() + *ngram; + left_char_weights.entry(ngram).or_insert_with(|| { + vec![0; self.char_window_size * self.tag_ids.len()] + })[idx] = weight; + } + TagFeature::RightCharacterNgram(StringNgramFeature { + rel_position, + ngram, + }) => { + let pos = self.char_window_size as isize - rel_position; + let idx = i as usize + pos as usize * self.tag_ids.len(); + if let Some(weights) = right_char_weights.get_mut(*ngram) { + weights[idx] = weight; + } else { + let mut weights = vec![0; self.char_window_size * self.tag_ids.len()]; + weights[idx] = weight; + right_char_weights.insert(ngram.to_string(), weights); + } + } + TagFeature::RightCharacterNgramEos(StringNgramFeature { + rel_position, + ngram, + }) => { + let pos = self.char_window_size as isize - rel_position; + let idx = i as usize + pos as usize * self.tag_ids.len(); + 
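+                        // As in the BOS case above, a NUL byte marks the sentence edge so
+                        // that edge n-grams remain distinct from ordinary character n-grams
+                        // with the same content.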
let ngram = ngram.to_string() + "\0"; + right_char_weights.entry(ngram).or_insert_with(|| { + vec![0; self.char_window_size * self.tag_ids.len()] + })[idx] = weight; + } + TagFeature::Character(ngram) => { + if let Some(weights) = self_char_weights.get_mut(*ngram) { + weights[i as usize] = weight; + } else { + let mut weights = vec![0; self.tag_ids.len()]; + weights[i as usize] = weight; + self_char_weights.insert(ngram.to_string(), weights); + } + } + }; + } + } + Ok(TagModel { + class_info, + left_char_model: NgramModel::new( + left_char_weights + .into_iter() + .map(|(ngram, weights)| NgramData { ngram, weights }) + .collect(), + ), + right_char_model: NgramModel::new( + right_char_weights + .into_iter() + .map(|(ngram, weights)| NgramData { ngram, weights }) + .collect(), + ), + self_char_model: NgramModel::new( + self_char_weights + .into_iter() + .map(|(ngram, weights)| NgramData { ngram, weights }) + .collect(), + ), + }) + } +} diff --git a/vaporetto/src/trainer.rs b/vaporetto/src/trainer.rs index b81de851..96b1fc5c 100644 --- a/vaporetto/src/trainer.rs +++ b/vaporetto/src/trainer.rs @@ -1,12 +1,64 @@ +use std::borrow::Borrow; use std::collections::BTreeMap; +use std::collections::HashMap; +use std::hash::Hash; use std::str::FromStr; -use anyhow::{anyhow, Result}; +use liblinear::LibLinearModel; -use crate::feature::{ExampleGenerator, FeatureExtractor}; +use crate::dict_model::{DictModel, DictWeight, WordWeightRecord}; +use crate::errors::{Result, VaporettoError}; +use crate::feature::{ + BoundaryExampleGenerator, BoundaryFeature, BytesNgramFeature, DictionaryWordFeature, + DictionaryWordPosition, StringNgramFeature, +}; use crate::model::Model; -use crate::sentence::Sentence; -use crate::utils::FeatureIDManager; +use crate::ngram_model::{NgramData, NgramModel}; +use crate::sentence::{BoundaryType, Sentence}; +use crate::tag_trainer::TagTrainer; + +// Bit depth for weight quantization. +pub const QUANTIZE_BIT_DEPTH: u8 = 16; + +pub struct Indexer { + ids: HashMap, + keys: Vec, +} + +impl Indexer +where + K: Eq + Hash, +{ + pub fn new() -> Self { + Self { + ids: HashMap::new(), + keys: vec![], + } + } + + pub fn get_id(&mut self, key: &Q) -> usize + where + K: Borrow, + Q: ToOwned + Eq + Hash, + { + if let Some(&id) = self.ids.get(key) { + id + } else { + let id = self.ids.len(); + self.keys.push(key.to_owned()); + self.ids.insert(key.to_owned(), id); + id + } + } + + pub fn len(&self) -> usize { + self.keys.len() + } + + pub fn keys(&self) -> &[K] { + &self.keys + } +} /// Solver type. #[derive(Clone, Copy, Debug)] @@ -69,21 +121,46 @@ impl From for liblinear::SolverType { } } -/// Dataset manager. +/// Trainer. 
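+///
+/// Collects annotated sentences and learns word boundary and tag models with LIBLINEAR.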
+/// +/// # Examples +/// +/// ```no_run +/// use std::fs::File; +/// use std::io::{prelude::*, BufReader, BufWriter}; +/// +/// use vaporetto::{Sentence, SolverType, Trainer}; +/// +/// let mut train_sents = vec![]; +/// let f = BufReader::new(File::open("dataset-train.txt").unwrap()); +/// for (i, line) in f.lines().enumerate() { +/// train_sents.push(Sentence::from_tokenized(line.unwrap()).unwrap()); +/// } +/// +/// let dict: Vec = vec![]; +/// let mut trainer = Trainer::new(3, 3, 3, 3, &dict, 0).unwrap(); +/// for (i, s) in train_sents.iter().enumerate() { +/// trainer.push_sentence(s); +/// } +/// +/// let model = trainer.train(0.01, 1., SolverType::L1RegularizedL2LossSVC).unwrap(); +/// let mut f = BufWriter::new(File::create("model.bin").unwrap()); +/// model.write(&mut f).unwrap(); +/// ``` #[cfg_attr(docsrs, doc(cfg(feature = "train")))] -pub struct Dataset<'a> { - dictionary: Vec>, - feature_extractor: FeatureExtractor, - example_generator: ExampleGenerator, +pub struct Trainer<'a> { + dictionary: Vec, + example_generator: BoundaryExampleGenerator, char_window_size: usize, type_window_size: usize, - dict_word_max_size: usize, - fid_manager: FeatureIDManager<'a>, + dict_max_word_size: usize, + feature_ids: Indexer>, xs: Vec>, ys: Vec, + tag_trainer: TagTrainer<'a>, } -impl<'a> Dataset<'a> { +impl<'a> Trainer<'a> { /// Creates a new dataset manager. /// /// # Arguments @@ -93,7 +170,7 @@ impl<'a> Dataset<'a> { /// * `type_ngram_size` - The character type n-gram length. /// * `type_window_size` - The character type window size. /// * `dictionary` - A word dictionary. - /// * `dict_word_max_size` - Dictionary words greater than this value will be grouped together. + /// * `dict_max_word_size` - Dictionary words greater than this value will be grouped together. /// /// # Returns /// @@ -108,7 +185,7 @@ impl<'a> Dataset<'a> { type_ngram_size: usize, type_window_size: usize, dictionary: D, - dict_word_max_size: usize, + dict_max_word_size: usize, ) -> Result where D: AsRef<[P]>, @@ -118,21 +195,23 @@ impl<'a> Dataset<'a> { dictionary: dictionary .as_ref() .iter() - .map(|word| (word.as_ref() as &[u8]).to_vec()) + .map(|word| (word.as_ref() as &str).to_string()) .collect(), - feature_extractor: FeatureExtractor::new( + example_generator: BoundaryExampleGenerator::new( char_ngram_size, type_ngram_size, - dictionary, - dict_word_max_size, + char_window_size, + type_window_size, + Some(dictionary.as_ref()).filter(|d| !d.is_empty()), + dict_max_word_size, )?, - example_generator: ExampleGenerator::new(char_window_size, type_window_size), char_window_size, type_window_size, - dict_word_max_size, - fid_manager: FeatureIDManager::default(), + dict_max_word_size, + feature_ids: Indexer::new(), xs: vec![], ys: vec![], + tag_trainer: TagTrainer::new(char_ngram_size, char_window_size), }) } @@ -141,22 +220,26 @@ impl<'a> Dataset<'a> { /// # Arguments /// /// * `s` - A sentence. - pub fn push_sentence(&mut self, s: &'a Sentence) { - let feature_spans = self.feature_extractor.extract(s); - let examples = self.example_generator.generate(s, feature_spans, false); + /// + /// # Errors + /// + /// [`VaporettoError::InvalidArgument`] will be returned if the maximum number of feature has + /// been reached. 
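+    ///
+    /// The sentence contributes both boundary examples and, through the internal tag
+    /// trainer, tag examples.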
+ pub fn push_sentence(&mut self, s: &'a Sentence) -> Result<()> { + let examples = self.example_generator.generate(s); for example in examples { let mut feature_ids = BTreeMap::new(); - for f in example.features { - let fid = self.fid_manager.get_id(f) + 1; - if let Some(v) = feature_ids.get_mut(&fid) { - *v += 1.0; - } else { - feature_ids.insert(fid, 1.0); - } + for f in &example.features { + let fid = self.feature_ids.get_id(f); + *feature_ids + .entry((fid + 1).try_into().unwrap()) + .or_insert(0.0) += 1.0; } self.xs.push(feature_ids.into_iter().collect()); self.ys.push(example.label as u8 as f64); } + self.tag_trainer.push_sentence(s)?; + Ok(()) } /// Gets the number of features. @@ -165,94 +248,144 @@ impl<'a> Dataset<'a> { /// /// The number of features. pub fn n_features(&self) -> usize { - self.fid_manager.map.len() + self.feature_ids.len() } -} - -/// Trainer. -/// -/// # Examples -/// -/// ```no_run -/// use std::fs::File; -/// use std::io::{prelude::*, BufReader, BufWriter}; -/// -/// use vaporetto::{Dataset, Sentence, SolverType, Trainer}; -/// -/// let mut train_sents = vec![]; -/// let f = BufReader::new(File::open("dataset-train.txt").unwrap()); -/// for (i, line) in f.lines().enumerate() { -/// train_sents.push(Sentence::from_tokenized(line.unwrap()).unwrap()); -/// } -/// -/// let dict: Vec = vec![]; -/// let mut dataset = Dataset::new(3, 3, 3, 3, &dict, 0).unwrap(); -/// for (i, s) in train_sents.iter().enumerate() { -/// dataset.push_sentence(s); -/// } -/// -/// let trainer = Trainer::new(0.01, 1., 1.); -/// let model = trainer.train(dataset, SolverType::L1RegularizedL2LossSVC).unwrap(); -/// let mut f = BufWriter::new(File::create("model.bin").unwrap()); -/// model.write(&mut f).unwrap(); -/// ``` -#[cfg_attr(docsrs, doc(cfg(feature = "train")))] -pub struct Trainer { - epsilon: f64, - cost: f64, - bias: f64, -} -impl Trainer { - /// Creates a new trainer. - /// - /// # Arguments - /// - /// * `epsilon` - The tolerance of the termination criterion. - /// * `cost` - The parameter C. - /// * `bias` - The bias term. + /// Gets the number of tag features. /// /// # Returns /// - /// A new trainer. - pub const fn new(epsilon: f64, cost: f64, bias: f64) -> Self { - Self { - epsilon, - cost, - bias, - } + /// The number of tag features. + pub fn n_tag_features(&self) -> usize { + self.tag_trainer.n_features() } - /// Trains a given dataset. + /// Trains word boundaries. /// /// # Arguments /// - /// * `dataset` - A dataset. + /// * `epsilon` - The tolerance of the termination criterion. + /// * `cost` - The parameter C. /// * `solver` - Solver type. /// /// # Returns /// /// A trained model. 
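 ///
 /// Training consumes the trainer, quantizes the learned weights to fit 16-bit
 /// integers (see `QUANTIZE_BIT_DEPTH`), and also trains the internal tag model.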
- pub fn train(&self, dataset: Dataset, solver: SolverType) -> Result { + pub fn train(self, epsilon: f64, cost: f64, solver: SolverType) -> Result { let mut builder = liblinear::Builder::new(); - let training_input = - liblinear::util::TrainingInput::from_sparse_features(dataset.ys, dataset.xs) - .map_err(|e| anyhow!("liblinear error: {:?}", e))?; - builder.problem().input_data(training_input).bias(self.bias); + let training_input = liblinear::util::TrainingInput::from_sparse_features(self.ys, self.xs) + .map_err(|e| VaporettoError::invalid_model(format!("liblinear error: {:?}", e)))?; + builder.problem().input_data(training_input).bias(1.0); builder .parameters() .solver_type(solver.into()) - .stopping_criterion(self.epsilon) - .constraints_violation_cost(self.cost); - let model = builder.build_model().map_err(|e| anyhow!(e.to_string()))?; - - Ok(Model::from_liblinear_model( - model, - dataset.fid_manager, - dataset.dictionary, - dataset.char_window_size, - dataset.type_window_size, - dataset.dict_word_max_size, - )) + .stopping_criterion(epsilon) + .constraints_violation_cost(cost); + let model = builder + .build_model() + .map_err(|e| VaporettoError::invalid_model(e.to_string()))?; + + let wb_idx = model + .labels() + .iter() + .position(|&cls| BoundaryType::WordBoundary as i32 == cls) + .unwrap() as i32; + + let bias = model.label_bias(wb_idx); + + // Uses BTreeMap to increase compression ratio. + let mut char_ngram_weights: BTreeMap<_, Vec<_>> = BTreeMap::new(); + let mut type_ngram_weights: BTreeMap<_, Vec<_>> = BTreeMap::new(); + let mut dict_weights = vec![DictWeight::default(); self.dict_max_word_size]; + + let mut weight_max = bias.abs(); + for fid in 0..model.num_features() { + let weight = model.feature_coefficient(fid as i32, wb_idx).abs(); + if weight > weight_max { + weight_max = weight; + } + } + let quantize_multiplier = weight_max / ((1 << (QUANTIZE_BIT_DEPTH - 1)) - 1) as f64; + + let bias = (bias / quantize_multiplier) as i32; + + for (fid, feature) in self.feature_ids.keys().iter().enumerate() { + let raw_weight = model.feature_coefficient(fid as i32 + 1, wb_idx); + let weight = (raw_weight / quantize_multiplier) as i32; + + if weight == 0 { + continue; + } + + match feature { + BoundaryFeature::CharacterNgram(StringNgramFeature { + rel_position, + ngram, + }) => { + let len = ngram.chars().count(); + let pos = self.char_window_size as isize - len as isize - rel_position; + if let Some(weights) = char_ngram_weights.get_mut(*ngram) { + weights[pos as usize] = weight; + } else { + let mut weights = vec![0; self.char_window_size * 2 - len + 1]; + weights[pos as usize] = weight; + char_ngram_weights.insert(ngram.to_string(), weights); + } + } + BoundaryFeature::CharacterTypeNgram(BytesNgramFeature { + rel_position, + ngram, + }) => { + let len = ngram.len(); + let pos = self.char_window_size as isize - len as isize - rel_position; + if let Some(weights) = type_ngram_weights.get_mut(*ngram) { + weights[pos as usize] = weight; + } else { + let mut weights = vec![0; self.char_window_size * 2 - len + 1]; + weights[pos as usize] = weight; + type_ngram_weights.insert(ngram.to_vec(), weights); + } + } + BoundaryFeature::DictionaryWord(DictionaryWordFeature { position, length }) => { + match position { + DictionaryWordPosition::Right => dict_weights[length - 1].right = weight, + DictionaryWordPosition::Inside => dict_weights[length - 1].inside = weight, + DictionaryWordPosition::Left => dict_weights[length - 1].left = weight, + } + } + }; + } + let tag_model = 
self.tag_trainer.train(epsilon, cost, solver)?; + Ok(Model { + char_ngram_model: NgramModel::new( + char_ngram_weights + .into_iter() + .map(|(ngram, weights)| NgramData { ngram, weights }) + .collect(), + ), + type_ngram_model: NgramModel::new( + type_ngram_weights + .into_iter() + .map(|(ngram, weights)| NgramData { ngram, weights }) + .collect(), + ), + dict_model: DictModel::new( + self.dictionary + .into_iter() + .map(|word| { + let idx = word.chars().count().min(dict_weights.len()) - 1; + WordWeightRecord { + word, + weights: dict_weights[idx], + comment: "".to_string(), + } + }) + .collect(), + ), + bias, + tag_model, + char_window_size: self.char_window_size, + type_window_size: self.type_window_size, + }) } } diff --git a/vaporetto/src/type_scorer.rs b/vaporetto/src/type_scorer.rs index 0254d663..c3928943 100644 --- a/vaporetto/src/type_scorer.rs +++ b/vaporetto/src/type_scorer.rs @@ -1,138 +1,160 @@ -use crate::model::ScoreValue; -use crate::sentence::Sentence; +use std::cell::RefCell; +use std::collections::BTreeMap; + use daachorse::DoubleArrayAhoCorasick; +use crate::errors::{Result, VaporettoError}; +use crate::ngram_model::NgramModel; +use crate::sentence::Sentence; +use crate::utils::AddWeight; + pub enum TypeScorer { Pma(TypeScorerPma), Cache(TypeScorerCache), } impl TypeScorer { - pub fn new( - pma: DoubleArrayAhoCorasick, - weights: Vec>, - window_size: usize, - ) -> Self { - if window_size <= 3 { - Self::Cache(TypeScorerCache::new(pma, weights, window_size)) + pub fn new(model: NgramModel>, window_size: usize) -> Result { + Ok(if window_size <= 3 { + Self::Cache(TypeScorerCache::new(model, window_size)?) } else { - Self::Pma(TypeScorerPma::new(pma, weights, window_size)) - } + Self::Pma(TypeScorerPma::new(model, window_size)?) + }) } - pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { match self { - TypeScorer::Pma(pma) => pma.add_scores(sentence, start, ys), - TypeScorer::Cache(cache) => cache.add_scores(sentence, start, ys), + TypeScorer::Pma(pma) => pma.add_scores(sentence, ys), + TypeScorer::Cache(cache) => cache.add_scores(sentence, ys), } } } pub struct TypeScorerPma { pma: DoubleArrayAhoCorasick, - weights: Vec>, + weights: Vec>, window_size: usize, } impl TypeScorerPma { - pub fn new( - pma: DoubleArrayAhoCorasick, - weights: Vec>, - window_size: usize, - ) -> Self { - Self { - pma, - weights, - window_size, + pub fn new(model: NgramModel>, window_size: usize) -> Result { + // key: ngram, value: (weight, check) + let mut weights_map: BTreeMap, RefCell<(Vec, bool)>> = BTreeMap::new(); + + for d in model.data { + weights_map.insert(d.ngram, RefCell::new((d.weights, false))); } - } - pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { - let type_start = if start >= self.window_size { - start + 1 - self.window_size - } else { - 0 - }; - let type_end = std::cmp::min( - start + ys.len() + self.window_size, - sentence.char_type.len(), - ); - let char_type = &sentence.char_type[type_start..type_end]; - let padding = start - type_start + 1; - for m in self.pma.find_overlapping_no_suffix_iter(&char_type) { - let offset = m.end() as isize - self.window_size as isize - padding as isize; - let weights = &self.weights[m.pattern()]; - if offset >= 0 { - for (w, y) in weights.iter().zip(&mut ys[offset as usize..]) { - *y += w; + let mut stack = vec![]; + for (ngram, data) in &weights_map { + if data.borrow().1 { + continue; + } + 
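+        // Walk from this n-gram down through its suffixes (stopping at the first one
+        // that is already merged) and fold the suffix weights into the longer n-grams.
+        // Scoring later uses find_overlapping_no_suffix_iter, which reports only the
+        // longest pattern ending at each position, so every pattern must carry the
+        // summed weights of all of its suffixes.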
stack.push(data); + for j in 1..ngram.len() { + if let Some(data) = weights_map.get(&ngram[j..]) { + stack.push(data); + if data.borrow().1 { + break; + } } - } else { - for (w, y) in weights[-offset as usize..].iter().zip(ys.iter_mut()) { - *y += w; + } + let mut data_from = stack.pop().unwrap(); + data_from.borrow_mut().1 = true; + while let Some(data_to) = stack.pop() { + let mut new_weight = data_from.borrow().0.clone(); + for (w1, w2) in new_weight.iter_mut().zip(&data_to.borrow().0) { + *w1 += w2; } + let new_data = (new_weight, true); + *data_to.borrow_mut() = new_data; + data_from = data_to; } } + let mut ngrams = vec![]; + let mut weights = vec![]; + for (ngram, data) in weights_map { + ngrams.push(ngram); + weights.push(data.into_inner().0); + } + let pma = DoubleArrayAhoCorasick::new(ngrams) + .map_err(|_| VaporettoError::invalid_model("invalid character type n-grams"))?; + Ok(Self { + pma, + weights, + window_size, + }) + } + + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { + for m in self + .pma + .find_overlapping_no_suffix_iter(&sentence.char_type) + { + let offset = m.end() as isize - self.window_size as isize - 1; + // Both the weights and the PMA always have the same number of items. + // Therefore, the following code is safe. + let weights = unsafe { self.weights.get_unchecked(m.value()) }; + weights.add_weight(ys, offset); + } } } pub struct TypeScorerCache { - scores: Vec, + scores: Vec, window_size: usize, sequence_mask: usize, } impl TypeScorerCache { - pub fn new( - pma: DoubleArrayAhoCorasick, - weights: Vec>, - window_size: usize, - ) -> Self { + pub fn new(model: NgramModel>, window_size: usize) -> Result { + let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)) + .map_err(|_| VaporettoError::invalid_model("invalid character type n-grams"))?; + let mut weights = vec![]; + for d in model.data { + if d.weights.len() <= 2 * window_size - d.ngram.len() { + return Err(VaporettoError::invalid_model( + "invalid size of weight vector", + )); + } + weights.push(d.weights); + } + let sequence_size = window_size * 2; let all_sequences = ALPHABET_SIZE.pow(sequence_size as u32); let mut sequence = vec![0u8; sequence_size]; - let mut scores = vec![0 as ScoreValue; all_sequences]; + let mut scores = vec![0; all_sequences]; for (i, score) in scores.iter_mut().enumerate() { if !Self::seqid_to_seq(i, &mut sequence) { continue; } - let mut y = ScoreValue::default(); - for m in pma.find_overlapping_no_suffix_iter(&sequence) { - y += weights[m.pattern()][sequence_size - m.end()]; + let mut y = 0; + for m in pma.find_overlapping_iter(&sequence) { + y += weights[m.value()][sequence_size - m.end()]; } *score = y; } - Self { + Ok(Self { scores, window_size, sequence_mask: (1 << (ALPHABET_SHIFT * sequence_size)) - 1, - } + }) } - pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { - let type_start = if start >= self.window_size { - start + 1 - self.window_size - } else { - 0 - }; - let type_end = std::cmp::min( - start + ys.len() + self.window_size, - sentence.char_type.len(), - ); - let char_type = &sentence.char_type[type_start..type_end]; - let offset = self.window_size + start; + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { let mut seqid = 0; - for i in 0..offset { - if let Some(ct) = char_type.get(i) { + for i in 0..self.window_size { + if let Some(ct) = sentence.char_type.get(i) { seqid = self.increment_seqid(seqid, *ct); } else { seqid = self.increment_seqid_without_char(seqid); }; } for (i, 
y) in ys.iter_mut().enumerate() { - if let Some(ct) = char_type.get(i + offset) { + if let Some(ct) = sentence.char_type.get(i + self.window_size) { seqid = self.increment_seqid(seqid, *ct); } else { seqid = self.increment_seqid_without_char(seqid); @@ -142,12 +164,12 @@ impl TypeScorerCache { } fn seqid_to_seq(mut seqid: usize, sequence: &mut [u8]) -> bool { - for i in (0..sequence.len()).rev() { + for type_id in sequence.iter_mut().rev() { let x = seqid & ALPHABET_MASK; if x == ALPHABET_MASK { return false; // invalid } - sequence[i] = ID_TO_TYPE[x]; + *type_id = ID_TO_TYPE[x]; seqid >>= ALPHABET_SHIFT; } assert_eq!(seqid, 0); @@ -155,7 +177,7 @@ impl TypeScorerCache { } #[inline(always)] - fn get_score(&self, seqid: usize) -> ScoreValue { + fn get_score(&self, seqid: usize) -> i32 { self.scores[seqid] } @@ -176,7 +198,7 @@ const ALPHABET_SIZE: usize = 8; const ALPHABET_MASK: usize = ALPHABET_SIZE - 1; const ALPHABET_SHIFT: usize = 3; const TYPE_TO_ID: [u32; 256] = make_type_to_id(); -const ID_TO_TYPE: [u8; 256] = make_id_to_type(); +const ID_TO_TYPE: [u8; ALPHABET_SIZE] = make_id_to_type(); const fn make_type_to_id() -> [u32; 256] { use crate::sentence::CharacterType::*; @@ -191,10 +213,10 @@ const fn make_type_to_id() -> [u32; 256] { type_to_id } -const fn make_id_to_type() -> [u8; 256] { +const fn make_id_to_type() -> [u8; ALPHABET_SIZE] { use crate::sentence::CharacterType::*; - let mut id_to_type = [0u8; 256]; + let mut id_to_type = [0u8; ALPHABET_SIZE]; id_to_type[1] = Digit as u8; id_to_type[2] = Roman as u8; id_to_type[3] = Hiragana as u8; diff --git a/vaporetto/src/utils.rs b/vaporetto/src/utils.rs index 47b51b80..8d926549 100644 --- a/vaporetto/src/utils.rs +++ b/vaporetto/src/utils.rs @@ -1,85 +1,171 @@ -#[cfg(feature = "train")] -use std::collections::HashMap; +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::io::{self, Read, Write}; -#[cfg(feature = "train")] -use crate::feature::Feature; - -#[cfg(feature = "train")] -pub struct FeatureIDManager<'a> { - pub(crate) map: HashMap, u32>, +pub trait AddWeight { + fn add_weight(&self, target: &mut [i32], offset: isize); } -#[cfg(feature = "train")] -impl<'a> FeatureIDManager<'a> { - pub fn new() -> Self { - Self { - map: HashMap::new(), +impl AddWeight for Vec { + fn add_weight(&self, ys: &mut [i32], offset: isize) { + if offset >= 0 { + if let Some(ys) = ys.get_mut(offset as usize..) { + for (w, y) in self.iter().zip(ys) { + *y += w; + } + } + } else if let Some(ws) = self.get(-offset as usize..) 
{ + for (w, y) in ws.iter().zip(ys.iter_mut()) { + *y += w; + } } } - - pub fn get_id(&mut self, feature: Feature<'a>) -> u32 { - self.map.get(&feature).copied().unwrap_or_else(|| { - let new_id = self.map.len() as u32; - self.map.insert(feature, new_id); - new_id - }) - } } -#[cfg(feature = "train")] -impl<'a> Default for FeatureIDManager<'a> { - fn default() -> Self { - Self::new() - } +pub trait MergableWeight { + fn from_two_weights(weight1: &Self, weight2: &Self, n_classes: usize) -> Self; } -#[cfg(feature = "train")] -pub struct StringIdManager { - pub(crate) map: HashMap, usize>, +pub struct WeightMerger { + map: BTreeMap>, + n_classes: usize, } -#[cfg(feature = "train")] -impl StringIdManager { - pub fn new() -> Self { +impl WeightMerger +where + W: MergableWeight, +{ + pub fn new(n_classes: usize) -> Self { Self { - map: HashMap::new(), + map: BTreeMap::new(), + n_classes, + } + } + + pub fn add(&mut self, ngram: &str, weight: W) { + if let Some(data) = self.map.get_mut(ngram) { + let (prev_weight, _) = &mut *data.borrow_mut(); + *prev_weight = W::from_two_weights(&weight, prev_weight, self.n_classes); + } else { + self.map + .insert(ngram.to_string(), RefCell::new((weight, false))); } } - pub fn get_id(&mut self, key: &[u8]) -> usize { - self.map.get(key).copied().unwrap_or_else(|| { - let new_id = self.map.len(); - self.map.insert(key.into(), new_id); - new_id - }) + pub fn merge(self) -> Vec<(String, W)> { + let mut stack = vec![]; + for (ngram, data) in &self.map { + if data.borrow().1 { + continue; + } + stack.push(data); + for (j, _) in ngram.char_indices().skip(1) { + if let Some(data) = self.map.get(&ngram[j..]) { + stack.push(data); + if data.borrow().1 { + break; + } + } + } + let mut data_from = stack.pop().unwrap(); + data_from.borrow_mut().1 = true; + while let Some(data_to) = stack.pop() { + let new_data = ( + W::from_two_weights(&data_from.borrow().0, &data_to.borrow().0, self.n_classes), + true, + ); + *data_to.borrow_mut() = new_data; + data_from = data_to; + } + } + self.map + .into_iter() + .map(|(ngram, weight)| (ngram, weight.into_inner().0)) + .collect() } } -#[cfg(test)] -#[allow(unused_macros)] -macro_rules! 
ct2u8 { - ( $( $v:path ),* ) => { - ct2u8!( $( $v, )* ) - }; - ( $( $v:path, )* ) => { - [ - $( - $v as u8, - )* - ] - }; +pub fn xor_or_zip_with(lhs: &Option, rhs: &Option, f: F) -> Option +where + T: Clone, + F: FnOnce(&T, &T) -> T, +{ + lhs.as_ref().map_or_else( + || rhs.clone(), + |x1| Some(rhs.as_ref().map_or_else(|| x1.clone(), |x2| f(x1, x2))), + ) +} + +#[cfg(feature = "kytea")] +pub fn read_u8(mut rdr: R) -> io::Result +where + R: Read, +{ + let mut buf = [0]; + rdr.read_exact(&mut buf)?; + Ok(buf[0]) +} + +#[cfg(feature = "kytea")] +pub fn read_u16(mut rdr: R) -> io::Result +where + R: Read, +{ + let mut buf = [0; 2]; + rdr.read_exact(&mut buf)?; + Ok(u16::from_le_bytes(buf)) +} + +#[cfg(feature = "kytea")] +pub fn read_i16(mut rdr: R) -> io::Result +where + R: Read, +{ + let mut buf = [0; 2]; + rdr.read_exact(&mut buf)?; + Ok(i16::from_le_bytes(buf)) +} + +pub fn read_u32(mut rdr: R) -> io::Result +where + R: Read, +{ + let mut buf = [0; 4]; + rdr.read_exact(&mut buf)?; + Ok(u32::from_le_bytes(buf)) +} + +pub fn write_u32(mut wtr: W, data: u32) -> io::Result<()> +where + W: Write, +{ + wtr.write_all(&data.to_le_bytes())?; + Ok(()) +} + +pub fn read_i32(mut rdr: R) -> io::Result +where + R: Read, +{ + let mut buf = [0; 4]; + rdr.read_exact(&mut buf)?; + Ok(i32::from_le_bytes(buf)) +} + +pub fn write_i32(mut wtr: W, data: i32) -> io::Result<()> +where + W: Write, +{ + wtr.write_all(&data.to_le_bytes())?; + Ok(()) } -#[cfg(test)] -macro_rules! ct2u8vec { - ( $( $v:path ),* ) => { - ct2u8vec!( $( $v, )* ) - }; - ( $( $v:path, )* ) => { - vec![ - $( - $v as u8, - )* - ] - }; +#[cfg(feature = "kytea")] +pub fn read_f64(mut rdr: R) -> io::Result +where + R: Read, +{ + let mut buf = [0; 8]; + rdr.read_exact(&mut buf)?; + Ok(f64::from_le_bytes(buf)) } diff --git a/vaporetto_rules/Cargo.toml b/vaporetto_rules/Cargo.toml index 3dda0fc9..2233f39f 100644 --- a/vaporetto_rules/Cargo.toml +++ b/vaporetto_rules/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vaporetto_rules" -version = "0.1.4" +version = "0.3.0" edition = "2018" authors = ["Koichi Akabe "] description = "Rule-base filters for Vaporetto" @@ -10,8 +10,7 @@ repository = "https://github.com/legalforce-research/vaporetto" readme = "README.md" keywords = ["japanese", "analyzer", "tokenizer", "morphological"] categories = ["text-processing"] -autotests = false [dependencies] -unicode-segmentation = "1.8.0" # MIT or Apache-2.0 -vaporetto = { path = "../vaporetto", version = "0.2.0" } # MIT or Apache-2.0 +unicode-segmentation = "1.9.0" # MIT or Apache-2.0 +vaporetto = { path = "../vaporetto", version = "0.3.0" } # MIT or Apache-2.0 diff --git a/vaporetto_rules/README.md b/vaporetto_rules/README.md index 6527833c..f8edbeac 100644 --- a/vaporetto_rules/README.md +++ b/vaporetto_rules/README.md @@ -8,6 +8,7 @@ vaporetto_rules is rule-base filters for Vaporetto. 
```rust use std::fs::File; use std::io::BufReader; +use std::rc::Rc; use vaporetto::{CharacterType, Model, Predictor, Sentence}; use vaporetto_rules::{ @@ -18,9 +19,9 @@ use vaporetto_rules::{ let mut f = BufReader::new(File::open("model.bin").unwrap()); let model = Model::read(&mut f).unwrap(); -let mut predictor = Predictor::new(model); +let mut predictor = Predictor::new(model).unwrap(); -let pre_filters: Vec>> = vec![ +let pre_filters: Vec> = vec![ Box::new(KyteaFullwidthFilter::new()), ]; let post_filters: Vec> = vec![ @@ -31,7 +32,9 @@ let post_filters: Vec> = vec![ let input = "Vaporettoは仲良し家族👨‍👨‍👧‍👦を離れ離れにさせません。" .to_string(); -let preproc_input = pre_filters.iter().fold(input, |s, filter| filter.filter(s)); +let input = Rc::new(input); +let preproc_input = pre_filters.iter().fold(input, |s, filter| Rc::new(filter.filter(&s))); +let preproc_input = Rc::try_unwrap(preproc_input).unwrap(); let sentence = Sentence::from_raw(preproc_input).unwrap(); let sentence = predictor.predict(sentence); diff --git a/vaporetto_rules/src/lib.rs b/vaporetto_rules/src/lib.rs index 8864c586..fb3585f5 100644 --- a/vaporetto_rules/src/lib.rs +++ b/vaporetto_rules/src/lib.rs @@ -7,6 +7,7 @@ //! ```no_run //! use std::fs::File; //! use std::io::BufReader; +//! use std::rc::Rc; //! //! use vaporetto::{CharacterType, Model, Predictor, Sentence}; //! use vaporetto_rules::{ @@ -17,20 +18,22 @@ //! //! let mut f = BufReader::new(File::open("model.bin").unwrap()); //! let model = Model::read(&mut f).unwrap(); -//! let mut predictor = Predictor::new(model); +//! let mut predictor = Predictor::new(model, false).unwrap(); //! -//! let pre_filters: Vec>> = vec![ -//! Box::new(KyteaFullwidthFilter::new()), +//! let pre_filters: Vec> = vec![ +//! Box::new(KyteaFullwidthFilter), //! ]; //! let post_filters: Vec> = vec![ -//! Box::new(ConcatGraphemeClustersFilter::new()), +//! Box::new(ConcatGraphemeClustersFilter), //! Box::new(KyteaWsConstFilter::new(CharacterType::Digit)), //! ]; //! //! let input = "Vaporettoは仲良し家族👨‍👨‍👧‍👦を離れ離れにさせません。" //! .to_string(); //! -//! let preproc_input = pre_filters.iter().fold(input, |s, filter| filter.filter(s)); +//! let input = Rc::new(input); +//! let preproc_input = pre_filters.iter().fold(input, |s, filter| Rc::new(filter.filter(&s))); +//! let preproc_input = Rc::try_unwrap(preproc_input).unwrap(); //! //! let sentence = Sentence::from_raw(preproc_input).unwrap(); //! let sentence = predictor.predict(sentence); @@ -49,7 +52,7 @@ pub mod string_filters; use vaporetto::Sentence; -pub trait SentenceFilter { +pub trait SentenceFilter: Send + Sync { /// Filter a specified sentence using rules. /// /// # Arguments: @@ -62,10 +65,7 @@ pub trait SentenceFilter { fn filter(&self, sentence: Sentence) -> Sentence; } -pub trait StringFilter -where - S: AsRef, -{ +pub trait StringFilter: Send + Sync { /// Filter a specified string using rules. /// /// # Arguments: @@ -75,5 +75,5 @@ where /// # Returns /// /// A processed string. 
- fn filter(&self, string: S) -> String; + fn filter(&self, string: &str) -> String; } diff --git a/vaporetto_rules/src/sentence_filters.rs b/vaporetto_rules/src/sentence_filters.rs index cd968e80..b701ec2a 100644 --- a/vaporetto_rules/src/sentence_filters.rs +++ b/vaporetto_rules/src/sentence_filters.rs @@ -2,6 +2,8 @@ mod concat_grapheme_clusters; mod kytea_wsconst; +mod split_linebreaks; pub use concat_grapheme_clusters::ConcatGraphemeClustersFilter; pub use kytea_wsconst::KyteaWsConstFilter; +pub use split_linebreaks::SplitLinebreaksFilter; diff --git a/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs b/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs index 287ae38b..39d149b2 100644 --- a/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs +++ b/vaporetto_rules/src/sentence_filters/concat_grapheme_clusters.rs @@ -4,47 +4,18 @@ use vaporetto::{BoundaryType, Sentence}; use crate::SentenceFilter; /// Grapheme cluster concatenator. +#[derive(Clone, Default)] pub struct ConcatGraphemeClustersFilter; -impl ConcatGraphemeClustersFilter { - /// Creates a new ConcatGraphemeClustersFilter. - /// - /// # Returns - /// - /// A new ConcatGraphemeClustersFilter. - pub const fn new() -> Self { - Self {} - } -} - -impl Default for ConcatGraphemeClustersFilter { - fn default() -> Self { - Self::new() - } -} - impl SentenceFilter for ConcatGraphemeClustersFilter { - /// Concatenates grapheme clusters. - /// - /// # Arguments: - /// - /// * `sentence` - Input sentence. - /// - /// # Returns - /// - /// A processed sentence. fn filter(&self, mut sentence: Sentence) -> Sentence { let mut tmp = sentence.boundaries().to_vec(); - for (i, c) in UnicodeSegmentation::grapheme_indices(sentence.to_raw_string(), true) { + for (i, c) in sentence.to_raw_string().grapheme_indices(true) { let start = sentence.get_char_pos(i).unwrap(); let end = sentence.get_char_pos(i + c.len()).unwrap() - 1; - for b in &mut tmp[start..end] { - *b = BoundaryType::NotWordBoundary; - } - } - for (b, t) in sentence.boundaries_mut().iter_mut().zip(&tmp) { - *b = *t; + tmp[start..end].fill(BoundaryType::NotWordBoundary); } + sentence.boundaries_mut().copy_from_slice(&tmp); sentence } } @@ -56,7 +27,7 @@ mod tests { #[test] fn test_concat_grapheme_clusters_no_boundary() { let s = Sentence::from_tokenized("\u{200d}").unwrap(); - let filter = ConcatGraphemeClustersFilter::new(); + let filter = ConcatGraphemeClustersFilter; let s = filter.filter(s); assert_eq!("\u{200d}", s.to_tokenized_string().unwrap()); } @@ -65,7 +36,7 @@ mod tests { fn test_concat_grapheme_clusters_zwj() { let s = Sentence::from_tokenized("\u{1f468} \u{200d} \u{1f469} \u{200d} \u{1f466}").unwrap(); - let filter = ConcatGraphemeClustersFilter::new(); + let filter = ConcatGraphemeClustersFilter; let s = filter.filter(s); assert_eq!( "\u{1f468}\u{200d}\u{1f469}\u{200d}\u{1f466}", @@ -76,7 +47,7 @@ mod tests { #[test] fn test_concat_grapheme_clusters_color() { let s = Sentence::from_tokenized("\u{1f44f} \u{1f3fd}").unwrap(); - let filter = ConcatGraphemeClustersFilter::new(); + let filter = ConcatGraphemeClustersFilter; let s = filter.filter(s); assert_eq!("\u{1f44f}\u{1f3fd}", s.to_tokenized_string().unwrap()); } @@ -84,7 +55,7 @@ mod tests { #[test] fn test_concat_grapheme_clusters_combined() { let s = Sentence::from_tokenized("これ は 手 \u{1f44f} \u{1f3fd} で す").unwrap(); - let filter = ConcatGraphemeClustersFilter::new(); + let filter = ConcatGraphemeClustersFilter; let s = filter.filter(s); assert_eq!( "これ は 手 \u{1f44f}\u{1f3fd} で 
す", diff --git a/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs b/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs index bd0d1318..07d69964 100644 --- a/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs +++ b/vaporetto_rules/src/sentence_filters/kytea_wsconst.rs @@ -3,6 +3,7 @@ use vaporetto::{BoundaryType, CharacterType, Sentence}; use crate::SentenceFilter; /// Character type concatenator. This filter works like KyTea's wsconst option. +#[derive(Clone)] pub struct KyteaWsConstFilter { char_type: CharacterType, } @@ -23,26 +24,14 @@ impl KyteaWsConstFilter { } impl SentenceFilter for KyteaWsConstFilter { - /// Concatenates consecutive character types. - /// - /// # Arguments: - /// - /// * `sentence` - Input sentence. - /// - /// # Returns - /// - /// A processed sentence. fn filter(&self, mut sentence: Sentence) -> Sentence { let t_flag = self.char_type as u8; - let mut tmp = sentence.boundaries().to_vec(); - for (i, (b, &t)) in tmp.iter_mut().zip(sentence.char_types()).enumerate() { - if t == t_flag && t == sentence.char_types()[i + 1] { + let (_, char_types, boundaries) = sentence.chars_and_boundaries_mut(); + for ((t1, t2), b) in char_types.iter().zip(&char_types[1..]).zip(boundaries) { + if *t1 == t_flag && *t2 == t_flag { *b = BoundaryType::NotWordBoundary; } } - for (b, t) in sentence.boundaries_mut().iter_mut().zip(&tmp) { - *b = *t; - } sentence } } diff --git a/vaporetto_rules/src/sentence_filters/split_linebreaks.rs b/vaporetto_rules/src/sentence_filters/split_linebreaks.rs new file mode 100644 index 00000000..71156946 --- /dev/null +++ b/vaporetto_rules/src/sentence_filters/split_linebreaks.rs @@ -0,0 +1,51 @@ +use vaporetto::{BoundaryType, Sentence}; + +use crate::SentenceFilter; + +/// Line breaks splitter. +#[derive(Clone, Default)] +pub struct SplitLinebreaksFilter; + +impl SentenceFilter for SplitLinebreaksFilter { + fn filter(&self, mut sentence: Sentence) -> Sentence { + let (chars, _, boundaries) = sentence.chars_and_boundaries_mut(); + for ((c1, c2), b) in chars.iter().zip(&chars[1..]).zip(boundaries) { + match (*c1, *c2) { + ('\r' | '\n', _) | (_, '\r' | '\n') => { + *b = BoundaryType::WordBoundary; + } + _ => {} + } + } + sentence + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_lf() { + let s = Sentence::from_tokenized("前の行\n次の行").unwrap(); + let filter = SplitLinebreaksFilter; + let s = filter.filter(s); + assert_eq!("前の行 \n 次の行", s.to_tokenized_string().unwrap()); + } + + #[test] + fn test_split_cr() { + let s = Sentence::from_tokenized("前の行\r次の行").unwrap(); + let filter = SplitLinebreaksFilter; + let s = filter.filter(s); + assert_eq!("前の行 \r 次の行", s.to_tokenized_string().unwrap()); + } + + #[test] + fn test_split_crlf() { + let s = Sentence::from_tokenized("前の行\r\n次の行").unwrap(); + let filter = SplitLinebreaksFilter; + let s = filter.filter(s); + assert_eq!("前の行 \r \n 次の行", s.to_tokenized_string().unwrap()); + } +} diff --git a/vaporetto_rules/src/string_filters/kytea_fullwidth.rs b/vaporetto_rules/src/string_filters/kytea_fullwidth.rs index befead3b..abefc99d 100644 --- a/vaporetto_rules/src/string_filters/kytea_fullwidth.rs +++ b/vaporetto_rules/src/string_filters/kytea_fullwidth.rs @@ -1,40 +1,12 @@ use crate::StringFilter; /// Half-width to full-width filter. This filter works like KyTea's preprocessor. +#[derive(Clone, Default)] pub struct KyteaFullwidthFilter; -impl KyteaFullwidthFilter { - /// Creates a new KyteaFullwidthFilter. - /// - /// # Returns - /// - /// A new KyteaFullwidthFilter. 
- pub const fn new() -> Self { - Self {} - } -} - -impl Default for KyteaFullwidthFilter { - fn default() -> Self { - Self::new() - } -} - -impl StringFilter for KyteaFullwidthFilter -where - S: AsRef, -{ - /// Replace alphanumerics and symbols to full-width characters. - /// - /// # Arguments: - /// - /// * `text` - Input text. - /// - /// # Returns - /// - /// A processed text. - fn filter(&self, string: S) -> String { - let mut chars: Vec<_> = string.as_ref().chars().collect(); +impl StringFilter for KyteaFullwidthFilter { + fn filter(&self, string: &str) -> String { + let mut chars: Vec<_> = string.chars().collect(); for c in &mut chars { *c = match *c { 'a' => 'a', diff --git a/vaporetto_tantivy/Cargo.toml b/vaporetto_tantivy/Cargo.toml new file mode 100644 index 00000000..4694b884 --- /dev/null +++ b/vaporetto_tantivy/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "vaporetto_tantivy" +version = "0.3.0" +edition = "2021" +authors = ["Koichi Akabe "] +description = "Vaporetto Tokenizer for Tantivy" +license = "MIT OR Apache-2.0" +homepage = "https://github.com/legalforce-research/vaporetto" +repository = "https://github.com/legalforce-research/vaporetto" +readme = "README.md" +keywords = ["japanese", "tokenizer", "tantivy"] +categories = ["text-processing"] + +[dependencies] +vaporetto = { path = "../vaporetto", version = "0.3.0" } # MIT or Apache-2.0 +vaporetto_rules = { path = "../vaporetto_rules", version = "0.3.0" } # MIT or Apache-2.0 +tantivy = "0.16" # MIT + +[dev-dependencies] +ruzstd = "0.2.4" # MIT diff --git a/vaporetto_tantivy/README.md b/vaporetto_tantivy/README.md new file mode 100644 index 00000000..acbb8d73 --- /dev/null +++ b/vaporetto_tantivy/README.md @@ -0,0 +1,40 @@ +# vaporetto_tantivy + +Vaporetto is a fast and lightweight pointwise prediction based tokenizer. +vaporetto_tantivy is a crate to use Vaporetto in [Tantivy](https://github.com/quickwit-oss/tantivy). + +# Example + +```rust +use std::fs::File; +use std::io::{Read, BufReader}; + +use tantivy::schema::{IndexRecordOption, Schema, TextFieldIndexing, TextOptions}; +use tantivy::Index; +use vaporetto::Model; +use vaporetto_tantivy::VaporettoTokenizer; + +let mut schema_builder = Schema::builder(); +let text_field_indexing = TextFieldIndexing::default() + .set_tokenizer("ja_vaporetto") + .set_index_option(IndexRecordOption::WithFreqsAndPositions); +let text_options = TextOptions::default() + .set_indexing_options(text_field_indexing) + .set_stored(); +schema_builder.add_text_field("title", text_options); +let schema = schema_builder.build(); +let index = Index::create_in_ram(schema); + +// Loads a model with decompression. +let mut f = BufReader::new(File::open("bccwj-suw+unidic.model.zst").unwrap()); +let mut decoder = ruzstd::StreamingDecoder::new(&mut f).unwrap(); +let mut buff = vec![]; +decoder.read_to_end(&mut buff).unwrap(); +let model = Model::read(&mut buff.as_slice()).unwrap(); + +// Creates VaporettoTokenizer with wsconst=DGR. +let tokenizer = VaporettoTokenizer::new(model, "DGR").unwrap(); +index + .tokenizers() + .register("ja_vaporetto", tokenizer); +``` diff --git a/vaporetto_tantivy/src/lib.rs b/vaporetto_tantivy/src/lib.rs new file mode 100644 index 00000000..ec2f375e --- /dev/null +++ b/vaporetto_tantivy/src/lib.rs @@ -0,0 +1,448 @@ +//! # vaporetto_tantivy +//! +//! Vaporetto Tokenizer for Tantivy +//! +//! ## Examples +//! +//! ```no_run +//! use std::fs::File; +//! use std::io::{Read, BufReader}; +//! +//! use tantivy::tokenizer::Tokenizer; +//! 
use vaporetto::Model; +//! use vaporetto_tantivy::VaporettoTokenizer; +//! +//! let mut f = BufReader::new(File::open("model.zst").unwrap()); +//! let mut decoder = ruzstd::StreamingDecoder::new(&mut f).unwrap(); +//! let mut buff = vec![]; +//! decoder.read_to_end(&mut buff).unwrap(); +//! let model = Model::read(&mut buff.as_slice()).unwrap(); +//! +//! let tokenizer = VaporettoTokenizer::new(model, "DGR").unwrap(); +//! +//! let mut stream = tokenizer.token_stream("東京特許許可局"); +//! +//! let token = stream.next().unwrap(); +//! assert_eq!(token.text, "東京"); +//! assert_eq!(token.offset_from, 0); +//! assert_eq!(token.offset_to, 6); +//! assert_eq!(token.position, 0); +//! +//! let token = stream.next().unwrap(); +//! assert_eq!(token.text, "特許"); +//! assert_eq!(token.offset_from, 6); +//! assert_eq!(token.offset_to, 12); +//! assert_eq!(token.position, 1); +//! +//! let token = stream.next().unwrap(); +//! assert_eq!(token.text, "許可"); +//! assert_eq!(token.offset_from, 12); +//! assert_eq!(token.offset_to, 18); +//! assert_eq!(token.position, 2); +//! +//! let token = stream.next().unwrap(); +//! assert_eq!(token.text, "局"); +//! assert_eq!(token.offset_from, 18); +//! assert_eq!(token.offset_to, 21); +//! assert_eq!(token.position, 3); +//! +//! assert!(stream.next().is_none()); +/// ``` +use std::sync::Arc; + +use tantivy::tokenizer::{BoxTokenStream, Token, TokenStream, Tokenizer}; +use vaporetto::{BoundaryType, CharacterType, Model, Predictor, Sentence}; +use vaporetto_rules::{ + sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter, SplitLinebreaksFilter}, + string_filters::KyteaFullwidthFilter, + SentenceFilter, StringFilter, +}; + +/// Tokenize the text using Vaporetto. +#[derive(Clone)] +pub struct VaporettoTokenizer { + predictor: Arc, + prefilter: KyteaFullwidthFilter, + postfilters: Vec>, +} + +impl VaporettoTokenizer { + /// Creates a new VaporettoTokenizer. + /// + /// # Arguments + /// + /// * `model` - A model data of Vaporetto. + /// * `wsconst` - Character types that the tokenizer does not segment. + /// D: Digit, R: Roman, H: Hiragana, T: Katakana, K: Kanji, O: Other, + /// G: Grapheme cluster. + /// + /// # Errors + /// + /// Error is returned when + /// - the model is invalid, or + /// - `wsconst` contains an invalid character type. 
+ pub fn new(model: Model, wsconst: &str) -> Result> { + let mut postfilters: Vec> = vec![Arc::new(SplitLinebreaksFilter)]; + for c in wsconst.chars() { + postfilters.push(match c { + 'D' => Arc::new(KyteaWsConstFilter::new(CharacterType::Digit)), + 'R' => Arc::new(KyteaWsConstFilter::new(CharacterType::Roman)), + 'H' => Arc::new(KyteaWsConstFilter::new(CharacterType::Hiragana)), + 'T' => Arc::new(KyteaWsConstFilter::new(CharacterType::Katakana)), + 'K' => Arc::new(KyteaWsConstFilter::new(CharacterType::Kanji)), + 'O' => Arc::new(KyteaWsConstFilter::new(CharacterType::Other)), + 'G' => Arc::new(ConcatGraphemeClustersFilter), + _ => return Err("Could not parse a wsconst value".into()), + }); + } + Ok(Self { + predictor: Arc::new(Predictor::new(model, false)?), + prefilter: KyteaFullwidthFilter, + postfilters, + }) + } +} + +pub struct VaporettoTokenStream<'a> { + text: &'a str, + token: Token, + boundary_pos: Vec, + offset_to: usize, + position: usize, +} + +impl Tokenizer for VaporettoTokenizer { + fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> { + if text.is_empty() { + return BoxTokenStream::from(VaporettoTokenStream { + text, + boundary_pos: vec![], + token: Token::default(), + offset_to: 0, + position: 0, + }); + } + + // pre filter + let prefiltered_text = self.prefilter.filter(text); + let prefiltered_sentence = Sentence::from_raw(prefiltered_text).unwrap(); + + // tokenize + let tokenized_sentence = self.predictor.predict(prefiltered_sentence); + + // post filter + let postfiltered_sentence = self + .postfilters + .iter() + .fold(tokenized_sentence, |s, filter| filter.filter(s)); + + let mut char_indices = text.char_indices(); + char_indices.next(); + let mut boundary_pos = Vec::with_capacity(postfiltered_sentence.chars().len()); + for ((i, _), &b) in char_indices.zip(postfiltered_sentence.boundaries()) { + if b == BoundaryType::WordBoundary { + boundary_pos.push(i); + } + } + boundary_pos.push(text.len()); + + BoxTokenStream::from(VaporettoTokenStream { + text, + token: Token::default(), + boundary_pos, + offset_to: 0, + position: 0, + }) + } +} + +impl<'a> TokenStream for VaporettoTokenStream<'a> { + fn advance(&mut self) -> bool { + if self.position < self.boundary_pos.len() { + self.token.offset_from = self.offset_to; + self.offset_to = self.boundary_pos[self.position]; + self.token.offset_to = self.offset_to; + self.token.text.clear(); + self.token + .text + .push_str(&self.text[self.token.offset_from..self.token.offset_to]); + self.token.position = self.position; + self.token.position_length = self.boundary_pos.len(); + self.position += 1; + true + } else { + false + } + } + + fn token(&self) -> &Token { + &self.token + } + + fn token_mut(&mut self) -> &mut Token { + &mut self.token + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::{Cursor, Read}; + + use tantivy::tokenizer::TextAnalyzer; + + fn token_stream_helper(text: &str, wsconst: &str) -> Vec { + let mut f = Cursor::new(include_bytes!("../test_model/model.zst")); + let mut decoder = ruzstd::StreamingDecoder::new(&mut f).unwrap(); + let mut buff = vec![]; + decoder.read_to_end(&mut buff).unwrap(); + let model = Model::read(&mut buff.as_slice()).unwrap(); + let a = TextAnalyzer::from(VaporettoTokenizer::new(model, wsconst).unwrap()); + let mut token_stream = a.token_stream(text); + let mut tokens: Vec = vec![]; + let mut add_token = |token: &Token| { + tokens.push(token.clone()); + }; + token_stream.process(&mut add_token); + tokens + } + + #[test] + fn test_tokenize_empty() { + let 
tokens = token_stream_helper("", ""); + + assert_eq!(tokens.len(), 0); + } + + #[test] + fn test_tokenizer_tokyo() { + let tokens = token_stream_helper("東京特許許可局", ""); + + assert_eq!(tokens.len(), 4); + + let token = &tokens[0]; + assert_eq!(token.text, "東京"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 4); + + let token = &tokens[1]; + assert_eq!(token.text, "特許"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 12); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 4); + + let token = &tokens[2]; + assert_eq!(token.text, "許可"); + assert_eq!(token.offset_from, 12); + assert_eq!(token.offset_to, 18); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 4); + + let token = &tokens[3]; + assert_eq!(token.text, "局"); + assert_eq!(token.offset_from, 18); + assert_eq!(token.offset_to, 21); + assert_eq!(token.position, 3); + assert_eq!(token.position_length, 4); + } + + #[test] + fn test_tokenizer_no_wsconst() { + let tokens = token_stream_helper("123456円🤌🏿", ""); + + assert_eq!(tokens.len(), 9); + + let token = &tokens[0]; + assert_eq!(token.text, "1"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 1); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 9); + + let token = &tokens[1]; + assert_eq!(token.text, "2"); + assert_eq!(token.offset_from, 1); + assert_eq!(token.offset_to, 2); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 9); + + let token = &tokens[2]; + assert_eq!(token.text, "3"); + assert_eq!(token.offset_from, 2); + assert_eq!(token.offset_to, 3); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 9); + + let token = &tokens[3]; + assert_eq!(token.text, "4"); + assert_eq!(token.offset_from, 3); + assert_eq!(token.offset_to, 4); + assert_eq!(token.position, 3); + assert_eq!(token.position_length, 9); + + let token = &tokens[4]; + assert_eq!(token.text, "5"); + assert_eq!(token.offset_from, 4); + assert_eq!(token.offset_to, 5); + assert_eq!(token.position, 4); + assert_eq!(token.position_length, 9); + + let token = &tokens[5]; + assert_eq!(token.text, "6"); + assert_eq!(token.offset_from, 5); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 5); + assert_eq!(token.position_length, 9); + + let token = &tokens[6]; + assert_eq!(token.text, "円"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 9); + assert_eq!(token.position, 6); + assert_eq!(token.position_length, 9); + + let token = &tokens[7]; + assert_eq!(token.text, "🤌"); + assert_eq!(token.offset_from, 9); + assert_eq!(token.offset_to, 13); + assert_eq!(token.position, 7); + assert_eq!(token.position_length, 9); + + let token = &tokens[8]; + assert_eq!(token.text, "🏿"); + assert_eq!(token.offset_from, 13); + assert_eq!(token.offset_to, 17); + assert_eq!(token.position, 8); + assert_eq!(token.position_length, 9); + } + + #[test] + fn test_tokenize_wsconst_d() { + let tokens = token_stream_helper("123456円🤌🏿", "D"); + + assert_eq!(tokens.len(), 4); + + let token = &tokens[0]; + assert_eq!(token.text, "123456"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 4); + + let token = &tokens[1]; + assert_eq!(token.text, "円"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 9); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 4); + + let token = &tokens[2]; + 
assert_eq!(token.text, "🤌"); + assert_eq!(token.offset_from, 9); + assert_eq!(token.offset_to, 13); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 4); + + let token = &tokens[3]; + assert_eq!(token.text, "🏿"); + assert_eq!(token.offset_from, 13); + assert_eq!(token.offset_to, 17); + assert_eq!(token.position, 3); + assert_eq!(token.position_length, 4); + } + + #[test] + fn test_tokenizer_wsconst_g() { + let tokens = token_stream_helper("123456円🤌🏿", "G"); + + assert_eq!(tokens.len(), 8); + + let token = &tokens[0]; + assert_eq!(token.text, "1"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 1); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 8); + + let token = &tokens[1]; + assert_eq!(token.text, "2"); + assert_eq!(token.offset_from, 1); + assert_eq!(token.offset_to, 2); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 8); + + let token = &tokens[2]; + assert_eq!(token.text, "3"); + assert_eq!(token.offset_from, 2); + assert_eq!(token.offset_to, 3); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 8); + + let token = &tokens[3]; + assert_eq!(token.text, "4"); + assert_eq!(token.offset_from, 3); + assert_eq!(token.offset_to, 4); + assert_eq!(token.position, 3); + assert_eq!(token.position_length, 8); + + let token = &tokens[4]; + assert_eq!(token.text, "5"); + assert_eq!(token.offset_from, 4); + assert_eq!(token.offset_to, 5); + assert_eq!(token.position, 4); + assert_eq!(token.position_length, 8); + + let token = &tokens[5]; + assert_eq!(token.text, "6"); + assert_eq!(token.offset_from, 5); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 5); + assert_eq!(token.position_length, 8); + + let token = &tokens[6]; + assert_eq!(token.text, "円"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 9); + assert_eq!(token.position, 6); + assert_eq!(token.position_length, 8); + + let token = &tokens[7]; + assert_eq!(token.text, "🤌🏿"); + assert_eq!(token.offset_from, 9); + assert_eq!(token.offset_to, 17); + assert_eq!(token.position, 7); + assert_eq!(token.position_length, 8); + } + + #[test] + fn test_tokenize_wsconst_dg() { + let tokens = token_stream_helper("123456円🤌🏿", "DG"); + + assert_eq!(tokens.len(), 3); + + let token = &tokens[0]; + assert_eq!(token.text, "123456"); + assert_eq!(token.offset_from, 0); + assert_eq!(token.offset_to, 6); + assert_eq!(token.position, 0); + assert_eq!(token.position_length, 3); + + let token = &tokens[1]; + assert_eq!(token.text, "円"); + assert_eq!(token.offset_from, 6); + assert_eq!(token.offset_to, 9); + assert_eq!(token.position, 1); + assert_eq!(token.position_length, 3); + + let token = &tokens[2]; + assert_eq!(token.text, "🤌🏿"); + assert_eq!(token.offset_from, 9); + assert_eq!(token.offset_to, 17); + assert_eq!(token.position, 2); + assert_eq!(token.position_length, 3); + } +} diff --git a/vaporetto_tantivy/test_model/model.zst b/vaporetto_tantivy/test_model/model.zst new file mode 100644 index 00000000..e51157d3 Binary files /dev/null and b/vaporetto_tantivy/test_model/model.zst differ diff --git a/vaporetto_wasm/Cargo.toml b/vaporetto_wasm/Cargo.toml index 2b7d7de2..99a7afad 100644 --- a/vaporetto_wasm/Cargo.toml +++ b/vaporetto_wasm/Cargo.toml @@ -12,3 +12,9 @@ vaporetto = { path = "../vaporetto" } # MIT or Apache-2.0 vaporetto_rules = { path = "../vaporetto_rules" } # MIT or Apache-2.0 wasm-bindgen = "0.2.75" # MIT or Apache-2.0 ruzstd = "0.2.4" # MIT +wee_alloc = "0.4.5" # MPL-2.0 + +[profile.release] +opt-level = "z" 
+codegen-units = 1
+lto = true
diff --git a/vaporetto_wasm/README.md b/vaporetto_wasm/README.md
index 7fde2175..9d7f4fde 100644
--- a/vaporetto_wasm/README.md
+++ b/vaporetto_wasm/README.md
@@ -1,19 +1,34 @@
 # WebAssembly example of Vaporetto
 
-1. Build a model file:
-   ```
-   # jp-0.4.7-5.mod is a model file distributed by KyTea.
-   cargo run --release -p convert_kytea_model -- --model-in ./jp-0.4.7-5.mod --model-out ../model/model.zstd
-   ```
+## How to build?
 
-2. Build a web assembly:
-   ```
-   % wasm-pack build --release --target web
-   ```
+1. Build a model file following the [documentation](../README.md).
+
+2. Build a JS file containing a web assembly using `build_portable_js.py`.
+   This script requires a model file, an identifier, and an output path.
 
-3. Launch the server:
+   The identifier must consist of alphanumeric characters and underscores.
    ```
-   % python3 -m http.server 8000
+   ./build_portable_js.py --model <model> --identifier <identifier> --output <output>
    ```
 
-4. Open http://localhost:8000/www
+3. You can use the generated JS file like the following code:
+   ```html
+   <!-- (HTML example omitted) -->
+   ```
diff --git a/vaporetto_wasm/build_portable_js.py b/vaporetto_wasm/build_portable_js.py
new file mode 100755
index 00000000..98459f4b
--- /dev/null
+++ b/vaporetto_wasm/build_portable_js.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+
+import argparse
+import base64
+import os
+import subprocess
+
+
+def _parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', required=True, help='A path to the model file')
+    parser.add_argument(
+        '--identifier', required=True, help='An identifier that is used to the function name'
+    )
+    parser.add_argument('--output', required=True, help='A path to the generated file')
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = _parse_args()
+
+    working_dir = os.path.dirname(os.path.abspath(__file__))
+
+    # Builds a wasm with a model file.
+    model_path = os.path.abspath(args.model)
+    env = os.environ.copy()
+    env['VAPORETTO_MODEL_PATH'] = model_path
+    subprocess.run(
+        ['wasm-pack', 'build', '--release', '--target', 'no-modules'],
+        cwd=working_dir,
+        env=env,
+    )
+
+    # Converts the wasm to the base64 string.
+    wasm_path = os.path.join(working_dir, 'pkg/vaporetto_wasm_bg.wasm')
+    with open(wasm_path, 'rb') as fp:
+        wasm_data = fp.read()
+    wasm_data_b64 = base64.b64encode(wasm_data).decode()
+
+    # Reads the glue js file.
+    js_path = os.path.join(working_dir, 'pkg/vaporetto_wasm.js')
+    with open(js_path, 'rt') as fp:
+        js_data = fp.read()
+
+    # Generates a unified js file.
+ with open(args.output, 'wt') as fp: + print( + js_data.replace('wasm_bindgen', f'__vaporetto_{args.identifier}_wbg'), + file=fp, + ) + print(f'async function vaporetto_{args.identifier}(){{', file=fp) + print(f' const data = "data:application/wasm;base64,{wasm_data_b64}";', file=fp) + print(f' await __vaporetto_{args.identifier}_wbg(fetch(data));', file=fp) + print(f' return __vaporetto_{args.identifier}_wbg.Vaporetto;', file=fp) + print('}', file=fp) diff --git a/vaporetto_wasm/src/lib.rs b/vaporetto_wasm/src/lib.rs index e7a9189d..32a74ff2 100644 --- a/vaporetto_wasm/src/lib.rs +++ b/vaporetto_wasm/src/lib.rs @@ -4,68 +4,125 @@ use js_sys::{Array, Object}; use vaporetto::{BoundaryType, CharacterType, Model, Predictor, Sentence}; use vaporetto_rules::{ sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter}, - SentenceFilter, + string_filters::KyteaFullwidthFilter, + SentenceFilter, StringFilter, }; use wasm_bindgen::{prelude::*, JsValue}; +#[global_allocator] +static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT; + #[wasm_bindgen] pub struct Vaporetto { predictor: Predictor, + fullwidth_filter: KyteaFullwidthFilter, post_filters: Vec>, } #[wasm_bindgen] impl Vaporetto { #[wasm_bindgen] - pub fn new() -> Self { - let mut f = Cursor::new(include_bytes!("../../model/model.zstd")); + pub fn new(filters: &str) -> Self { + let mut f = Cursor::new(include_bytes!(env!("VAPORETTO_MODEL_PATH"))); let mut decoder = ruzstd::StreamingDecoder::new(&mut f).unwrap(); let mut buff = vec![]; decoder.read_to_end(&mut buff).unwrap(); let model = Model::read(&mut buff.as_slice()).unwrap(); - let predictor = Predictor::new(model); - let post_filters: Vec> = vec![ - Box::new(ConcatGraphemeClustersFilter::new()), - Box::new(KyteaWsConstFilter::new(CharacterType::Digit)), - ]; + let predictor = Predictor::new(model, false).unwrap(); + let post_filters: Vec<_> = filters + .chars() + .map(|c| { + let b: Box = match c { + 'D' => Box::new(KyteaWsConstFilter::new(CharacterType::Digit)), + 'R' => Box::new(KyteaWsConstFilter::new(CharacterType::Roman)), + 'H' => Box::new(KyteaWsConstFilter::new(CharacterType::Hiragana)), + 'T' => Box::new(KyteaWsConstFilter::new(CharacterType::Katakana)), + 'K' => Box::new(KyteaWsConstFilter::new(CharacterType::Kanji)), + 'O' => Box::new(KyteaWsConstFilter::new(CharacterType::Other)), + 'G' => Box::new(ConcatGraphemeClustersFilter::new()), + _ => panic!("invalid filter: {}", c), + }; + b + }) + .collect(); Self { predictor, + fullwidth_filter: KyteaFullwidthFilter::new(), post_filters, } } #[wasm_bindgen] - pub fn predict_partial(&self, text: &str, start: usize, end: usize) -> Object { - let s = if let Ok(s) = Sentence::from_raw(text) { + pub fn tokenize(&self, text: &str) -> Object { + let result = Array::new(); + let mut s = if let Ok(s) = Sentence::from_raw(text) { s } else { - return JsValue::NULL.into(); + return result.into(); }; - if start >= end { - return JsValue::NULL.into(); + let norm = self.fullwidth_filter.filter(text); + let s_norm = if let Ok(s) = Sentence::from_raw(norm) { + s + } else { + return result.into(); + }; + let s_norm = self.predictor.predict(s_norm); + s.boundaries_mut().clone_from_slice(s_norm.boundaries()); + let s = self + .post_filters + .iter() + .fold(s, |s, filter| filter.filter(s)); + + if let Ok(tokens) = s.to_tokenized_vec() { + for token in tokens { + result.push(&JsValue::from_str(token.surface)); + } } - let s = self.predictor.predict_partial_with_score(s, start..end); + result.into() + } + + #[wasm_bindgen] + pub fn 
predict(&self, text: &str) -> Object { + let result = Array::new(); + let text = self.fullwidth_filter.filter(text); + let s = if let Ok(s) = Sentence::from_raw(text) { + s + } else { + return result.into(); + }; + let s = self.predictor.predict(s); let s = self .post_filters .iter() .fold(s, |s, filter| filter.filter(s)); + for &b in s.boundaries() { + result.push(&JsValue::from_bool(b == BoundaryType::WordBoundary)); + } + result.into() + } + + #[wasm_bindgen] + pub fn predict_with_score(&self, text: &str) -> Object { let result = Array::new(); - for (&score, &b) in s.boundary_scores().unwrap()[start..end] + let text = self.fullwidth_filter.filter(text); + let s = if let Ok(s) = Sentence::from_raw(text) { + s + } else { + return result.into(); + }; + let s = self.predictor.predict_with_score(s); + let s = self + .post_filters .iter() - .zip(&s.boundaries()[start..end]) - { + .fold(s, |s, filter| filter.filter(s)); + + for (&score, &b) in s.boundary_scores().iter().zip(s.boundaries()) { let boundary = Array::new(); - boundary.push(&JsValue::from_bool(b == BoundaryType::WordBoundary)); - boundary.push(&JsValue::from_f64(score)); + boundary.push(&(b == BoundaryType::WordBoundary).into()); + boundary.push(&score.into()); result.push(&boundary); } result.into() } } - -impl Default for Vaporetto { - fn default() -> Self { - Self::new() - } -} diff --git a/vaporetto_wasm/www/index.html b/vaporetto_wasm/www/index.html index a6a43b8c..8434232c 100644 --- a/vaporetto_wasm/www/index.html +++ b/vaporetto_wasm/www/index.html @@ -2,9 +2,10 @@ - Vaporetto Real-time Tokenization + Vaporetto Demo - + +
@@ -17,9 +18,8 @@
Output:
-
+

             
-
Loading...
diff --git a/vaporetto_wasm/www/index.js b/vaporetto_wasm/www/index.js index e7f0deb2..47496fe1 100644 --- a/vaporetto_wasm/www/index.js +++ b/vaporetto_wasm/www/index.js @@ -1,152 +1,30 @@ -import init from '../pkg/vaporetto_wasm.js'; -import * as wasm from '../pkg/vaporetto_wasm.js'; - -const loading = document.getElementById("loading"); -loading.style.display = "block"; - -function run() { - const predictor = wasm.Vaporetto.new(); - - loading.style.display = "none"; - - function createTextSpan(text) { - const span = document.createElement("span"); - const textnode = document.createTextNode(text); - span.appendChild(textnode); - return span; +function createTextSpan(text, isBoundary, score) { + const span = document.createElement("span"); + const textnode = document.createTextNode(text); + span.appendChild(textnode); + if (isBoundary) { + span.style.borderLeft = "5pt solid rgba(0, 0, 0, " + Math.atan(score / 2) + ")"; } + return span; +} - function replace_text(elem, prev_text, text, range_from, range_to, boundaries, window_size) { - const prev_boundary_start = Math.max(range_from[0] - window_size, 0); - const prev_boundary_end = Math.min(range_from[1] + window_size - 1, prev_text.length - 1); - const node_end_idx = prev_boundary_end + 1; - let node_end = elem.childNodes[0]; - if (prev_text.length != 0) { - node_end = elem.childNodes[node_end_idx]; - if (range_from[0] == 0) { - node_end.previousSibling.remove(); - } - for (let i = prev_boundary_end - prev_boundary_start; i > 0; --i) { - node_end.previousSibling.remove(); - } - } - const next_boundary_start = Math.max(range_to[0] - window_size, 0); - const next_boundary_end = Math.min(range_to[1] + window_size - 1, text.length - 1); - if (text.length != 0) { - if (range_to[0] == 0) { - node_end.before(createTextSpan(text[next_boundary_start])); - } - for (let i = 0; i < next_boundary_end - next_boundary_start; ++i) { - const elem = createTextSpan(text[next_boundary_start + i + 1]); - if (boundaries[i][0]) { - elem.style.borderLeft = '5pt solid rgba(0, 0, 0, ' + Math.atan(boundaries[i][1] / 2) + ')'; - } - node_end.before(elem); - } - } - } - - const input_text = document.getElementById('input_text'); - input_text.value = ""; - - const window_size = 3; - - let input_data = null; - let prev_range = [0, 0]; - let prev_chars = []; - let chars_pos_map = [0]; - - let composition_start = null; - input_text.addEventListener('compositionstart', function (e) { - composition_start = chars_pos_map[e.target.selectionStart]; - }); - - input_text.addEventListener('compositionend', function (e) { - composition_start = null; - }); - - input_text.addEventListener('beforeinput', function (e) { - input_data = e.data; - if (composition_start != null) { - prev_range = [composition_start, chars_pos_map[e.target.selectionEnd]]; - } else { - prev_range = [chars_pos_map[e.target.selectionStart], chars_pos_map[e.target.selectionEnd]]; - } - }); - - input_text.addEventListener('input', function (e) { - const t0 = performance.now(); +vaporetto_bccwj_suw_small().then((Vaporetto) => { + const vaporetto_suw = Vaporetto.new("DG"); - const cur_text = e.target.value; - const cur_chars = Array.from(cur_text); - chars_pos_map = new Array(cur_text.length); - let utf16_pos = 0; - for (let i = 0; i < cur_chars.length; ++i) { - chars_pos_map[utf16_pos] = i; - utf16_pos += cur_chars[i].length; + input_text.addEventListener("input", (e) => { + const text = input_text.value; + const scores = vaporetto_suw.predict_with_score(text); + let i = -1; + while (tokenized.firstChild) { + 
tokenized.removeChild(tokenized.firstChild); } - chars_pos_map.push(cur_chars.length); - - let range_from = null; - let range_to = null; - switch (e.inputType) { - case 'insertText': - case 'insertLineBreak': - case 'insertParagraph': - case 'insertFromPaste': - case 'insertCompositionText': - range_from = prev_range; - range_to = [prev_range[0], prev_range[1] + cur_chars.length - prev_chars.length]; - break; - case 'deleteWordBackward': - case 'deleteWordForward': - case 'deleteSoftLineBackward': - case 'deleteSoftLineForward': - case 'deleteEntireSoftLine': - case 'deleteHardLineBackward': - case 'deleteHardLineForward': - case 'deleteByCut': - case 'deleteContent': - case 'deleteContentBackward': - case 'deleteContentForward': - const start = chars_pos_map[e.target.selectionStart]; - const right_length = cur_chars.length - start; - const prev_end = prev_chars.length - right_length; - range_from = [start, prev_end]; - range_to = [start, start]; - break; - default: - range_from = [0, prev_chars.length]; - range_to = [0, cur_chars.length]; + for (let c of text) { + if (i >= 0) { + tokenized.appendChild(createTextSpan(c, scores[i][0], scores[i][1] / 10000)); + } else { + tokenized.appendChild(createTextSpan(c, false, 0)); + } + ++i; } - - const tokenized = document.getElementById("tokenized"); - - const predict_chars_start = Math.max(range_to[0] - window_size * 2 + 1, 0); - const predict_chars_end = Math.min(range_to[1] + window_size * 2 - 1, cur_chars.length); - const predict_chars = cur_chars.slice(predict_chars_start, predict_chars_end); - - const boundary_start = Math.max(range_to[0] - window_size, 0); - const boundary_end = Math.min(range_to[1] + window_size - 1, cur_chars.length - 1); - - const predict_boundary_start = boundary_start - predict_chars_start; - const predict_boundary_end = boundary_end - predict_chars_start; - - const boundaries = predictor.predict_partial(predict_chars.join(""), predict_boundary_start, predict_boundary_end); - - console.log("input with window:", predict_chars); - console.log("prediction range:", [predict_boundary_start, predict_boundary_end]); - console.log("boundaries:", boundaries); - - replace_text(tokenized, prev_chars, cur_chars, range_from, range_to, boundaries, window_size); - - const t1 = performance.now(); - - console.log("Elapsed:", t1 - t0, "[ms]"); - console.log("-----"); - - prev_chars = cur_chars; }); -} - -init().then(run); +});
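
The `MergableWeight` / `WeightMerger` helpers added earlier in this diff merge the weights of overlapping character n-grams: `merge()` folds the weight of every suffix n-gram that is also registered into the longer n-gram. Below is a minimal sketch of that behaviour, written as it might appear in a unit test inside the crate; these helpers are internal, and the `ClassWeights` type is made up purely for illustration.

```rust
// Hypothetical per-class weight vector (not part of Vaporetto).
#[derive(Clone, Debug, PartialEq)]
struct ClassWeights(Vec<f64>);

impl MergableWeight for ClassWeights {
    fn from_two_weights(weight1: &Self, weight2: &Self, n_classes: usize) -> Self {
        // Element-wise sum of the two per-class weight vectors.
        let mut ws = vec![0.0; n_classes];
        for (w, y) in weight1.0.iter().zip(ws.iter_mut()) {
            *y += w;
        }
        for (w, y) in weight2.0.iter().zip(ws.iter_mut()) {
            *y += w;
        }
        Self(ws)
    }
}

#[test]
fn merges_suffix_ngram_weights() {
    let mut merger = WeightMerger::new(2);
    merger.add("abc", ClassWeights(vec![1.0, 0.0]));
    merger.add("bc", ClassWeights(vec![0.0, 2.0]));
    merger.add("c", ClassWeights(vec![3.0, 0.0]));

    // merge() folds the weight of every suffix n-gram that also appears in
    // the map into the longer n-gram.
    let merged: std::collections::HashMap<_, _> = merger.merge().into_iter().collect();
    assert_eq!(merged["abc"], ClassWeights(vec![4.0, 2.0]));
    assert_eq!(merged["bc"], ClassWeights(vec![3.0, 2.0]));
    assert_eq!(merged["c"], ClassWeights(vec![3.0, 0.0]));
}
```

With an additive `from_two_weights` like this one, each entry returned by `merge()` ends up holding its own weight plus the weights of all of its suffixes, so looking up only the longest matching n-gram already yields the combined score.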
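
Since `SentenceFilter` now requires `Send + Sync` and the bundled filters are plain unit structs, a user-defined post-filter only needs to implement the single `filter` method. The following is a minimal sketch under that assumption; `JoinAllFilter` is a hypothetical example, not part of `vaporetto_rules`.

```rust
use vaporetto::{BoundaryType, Sentence};
use vaporetto_rules::SentenceFilter;

/// Hypothetical filter that removes every word boundary, forcing the whole
/// sentence to become a single token (for illustration only).
struct JoinAllFilter;

impl SentenceFilter for JoinAllFilter {
    fn filter(&self, mut sentence: Sentence) -> Sentence {
        for b in sentence.boundaries_mut() {
            *b = BoundaryType::NotWordBoundary;
        }
        sentence
    }
}

fn main() {
    let s = Sentence::from_tokenized("日本 語").unwrap();
    let s = JoinAllFilter.filter(s);
    assert_eq!("日本語", s.to_tokenized_string().unwrap());
}
```

Such a filter can be boxed and mixed with the bundled ones in the `post_filters` vector shown in the vaporetto_rules README example above.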