From 261e3f2bca7987253d63b77f5ab54a31cf5fe799 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Fri, 1 Apr 2022 19:54:04 +0900 Subject: [PATCH] Maintain dict scores in model by arrays instead of tuples (#29) * Change dictionary score to character-wise score * Update README * update README * fix readme * fix * fix * Apply suggestions from code review Co-authored-by: Shunsuke Kanda Co-authored-by: Shunsuke Kanda --- README-ja.md | 87 +++++++++++++++++------------------- README.md | 83 +++++++++++++++++----------------- manipulate_model/src/main.rs | 21 ++++----- vaporetto/src/char_scorer.rs | 12 +---- vaporetto/src/dict_model.rs | 44 ++++++++---------- vaporetto/src/kytea_model.rs | 15 ++++--- vaporetto/src/predictor.rs | 86 +++++++---------------------------- vaporetto/src/trainer.rs | 8 +++- 8 files changed, 140 insertions(+), 216 deletions(-) diff --git a/README-ja.md b/README-ja.md index 8dc567e9..95b3060c 100644 --- a/README-ja.md +++ b/README-ja.md @@ -113,52 +113,51 @@ Vaporetto は2種類のコーパス、すなわちフルアノテーションコ ### モデルの編集 -時々、モデルが期待とは異なる結果を出力することがあるでしょう。 -例えば、以下のコマンドで `メロンパン` は2つのトークンに分割されます。 +モデルが期待とは異なる結果を出力することがあるでしょう。 +例えば、以下のコマンドで `外国人参政権` は誤ったトークンに分割されます。 `--scores` オプションを使って、各文字間のスコアを出力します。 ``` -% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize.model.zst -朝食 は メロン パン 1 個 だっ た -0:朝食 -15398 -1:食は 24623 -2:はメ 30261 -3:メロ -26885 -4:ロン -38896 -5:ンパ 8162 -6:パン -23416 -7:ン1 23513 -8:1個 18435 -9:個だ 24964 -10:だっ -15065 -11:った 14178 +% echo '外国人参政権と政権交代' | cargo run --release -p predict -- --scores --model path/to/bccwj-suw+unidic.model.zst +外国 人 参 政権 と 政権 交代 +0:外国 -11785 +1:国人 16634 +2:人参 5450 +3:参政 4480 +4:政権 -3697 +5:権と 17702 +6:と政 18699 +7:政権 -12742 +8:権交 14578 +9:交代 -7658 ``` -`メロンパン` を単一のトークンに連結するには、以下の手順でモデルを編集し、 `ンパ` のスコアを負にします。 +正しくは `外国 人 参政 権` です。 +`外国人参政権` を正しいトークンに分割するには、以下の手順でモデルを編集し、 `参政権` のスコアの符号を反転させます。 1. 以下のコマンドで辞書を吐き出します。 ``` - % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --dump-dict path/to/dictionary.csv + % cargo run --release -p manipulate_model -- --model-in path/to/bccwj-suw+unidic.model.zst --dump-dict path/to/dictionary.csv ``` 2. 辞書を編集します。 - 辞書は CSV ファイルです。各行には単語と、対応する重みとコメントが以下の順で含まれています。 + 辞書は CSV ファイルです。各行には文字列パターン、対応する重み配列、コメントが以下のように含まれています。 - * `right_weight` - 単語が境界の右側に見つかった際に追加される重み。 - * `inside_weight` - 単語が境界に重なっている際に追加される重み。 - * `left_weight` - 単語が境界の左側に見つかった際に追加される重み。 + * `word` - 文字列パターン(主に単語) + * `weights` - 重み配列。入力文字列に対象の文字列パターンが含まれている場合、見つかったパターンの範囲の文字境界に対してこれらの重みが加算されます。 * `comment` - 挙動に影響しないコメント Vaporetto は、重みの合計が正の値になった際にテキストを分割するので、以下のように新しいエントリを追加します。 ```diff - メロレオストーシス,6944,-2553,5319, - メロン,8924,-10861,7081, - +メロンパン,0,-100000,0,melon🍈 bread🍞 in English. - メロン果実,4168,-1165,3558, - メロヴィング,6999,-15413,7583, + 参撾,3167 -6074 3790, + 参政,3167 -6074 3790, + +参政権,0 -10000 10000 0,参政/権 + 参朝,3167 -6074 3790, + 参校,3167 -6074 3790, ``` - この場合、境界が `メロンパン` の内側だった際に `-100000` が追加されます。 + この場合、 `参` と `政` の間に `-10000` が、 `政` と `権` の間に `10000` が加算されます。 + パターンの両端では `0` が指定されているため、スコアは加算されません。 Vaporetto は重みの合計値に 32-bit 整数を利用しているため、オーバーフローに気をつけてください。 @@ -167,25 +166,23 @@ Vaporetto は2種類のコーパス、すなわちフルアノテーションコ 3. 
モデルファイルの重みを置換します。 ``` - % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --replace-dict path/to/dictionary.csv --model-out path/to/jp-0.4.7-5-tokenize-new.model.zst + % cargo run --release -p manipulate_model -- --model-in path/to/bccwj-suw+unidic.model.zst --replace-dict path/to/dictionary.csv --model-out path/to/bccwj-suw+unidic-new.model.zst ``` -これで `メロンパン` が単一のトークンに分割されます。 -``` -% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize-new.model.zst -朝食 は メロンパン 1 個 だっ た -0:朝食 -15398 -1:食は 24623 -2:はメ 30261 -3:メロ -126885 -4:ロン -138896 -5:ンパ -91838 -6:パン -123416 -7:ン1 23513 -8:1個 18435 -9:個だ 24964 -10:だっ -15065 -11:った 14178 +これで `外国人参政権` が正しいトークンに分割されます。 +``` +% echo '外国人参政権と政権交代' | cargo run --release -p predict -- --scores --model path/to/bccwj-suw+unidic-new.model.zst +外国 人 参政 権 と 政権 交代 +0:外国 -11785 +1:国人 16634 +2:人参 5450 +3:参政 -5520 +4:政権 6303 +5:権と 17702 +6:と政 18699 +7:政権 -12742 +8:権交 14578 +9:交代 -7658 ``` ### 品詞推定 diff --git a/README.md b/README.md index 7607d604..96bd9f08 100644 --- a/README.md +++ b/README.md @@ -115,51 +115,50 @@ You can specify all arguments above multiple times. ### Model Manipulation Sometimes, your model will output different results than what you expect. -For example, `メロンパン` is split into two tokens in the following command. +For example, `外国人参政権` is split into wrong tokens in the following command. We use `--scores` option to show the score of each character boundary: ``` -% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize.model.zst -朝食 は メロン パン 1 個 だっ た -0:朝食 -15398 -1:食は 24623 -2:はメ 30261 -3:メロ -26885 -4:ロン -38896 -5:ンパ 8162 -6:パン -23416 -7:ン1 23513 -8:1個 18435 -9:個だ 24964 -10:だっ -15065 -11:った 14178 +% echo '外国人参政権と政権交代' | cargo run --release -p predict -- --scores --model path/to/bccwj-suw+unidic.model.zst +外国 人 参 政権 と 政権 交代 +0:外国 -11785 +1:国人 16634 +2:人参 5450 +3:参政 4480 +4:政権 -3697 +5:権と 17702 +6:と政 18699 +7:政権 -12742 +8:権交 14578 +9:交代 -7658 ``` -To concatenate `メロンパン` into a single token, manipulate the model in the following steps so that the score of `ンパ` becomes negative: +The correct is `外国 人 参政 権`. +To split `外国人参政権` into correct tokens, manipulate the model in the following steps so that the sign of score of `参政権` becomes inverted: 1. Dump a dictionary by the following command: ``` - % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --dump-dict path/to/dictionary.csv + % cargo run --release -p manipulate_model -- --model-in path/to/bccwj-suw+unidic.model.zst --dump-dict path/to/dictionary.csv ``` 2. Edit the dictionary. - The dictionary is a csv file. Each row contains a word, corresponding weights, and a comment in the following order: + The dictionary is a csv file. Each row contains a string pattern, corresponding weight array, and a comment in the following order: - * `right_weight` - A weight that is added when the word is found to the right of the boundary. - * `inside_weight` - A weight that is added when the word is overlapped on the boundary. - * `left_weight` - A weight that is added when the word is found to the left of the boundary. + * `word` - A string pattern (usually, a word) + * `weights` - A weight array. When the string pattern is contained in the input string, these weights are added to character boundaries of the range of the found pattern. * `comment` - A comment that does not affect the behaviour. 
Vaporetto splits a text when the total weight of the boundary is a positive number, so we add a new entry as follows: ```diff - メロレオストーシス,6944,-2553,5319, - メロン,8924,-10861,7081, - +メロンパン,0,-100000,0,melon🍈 bread🍞 in English. - メロン果実,4168,-1165,3558, - メロヴィング,6999,-15413,7583, + 参撾,3167 -6074 3790, + 参政,3167 -6074 3790, + +参政権,0 -10000 10000 0,参政/権 + 参朝,3167 -6074 3790, + 参校,3167 -6074 3790, ``` - In this case, `-100000` will be added when the boundary is inside of the word `メロンパン`. + In this case, `-10000` will be added between `参` and `政`, and `10000` will be added between `政` and `権`. + Because `0` is specified at both ends of the pattern, no scores are added at those positions. Note that Vaporetto uses 32-bit integers for the total weight, so you have to be careful about overflow. @@ -168,25 +167,23 @@ To concatenate `メロンパン` into a single token, manipulate the model in th 3. Replaces weight data of a model file ``` - % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --replace-dict path/to/dictionary.csv --model-out path/to/jp-0.4.7-5-tokenize-new.model.zst + % cargo run --release -p manipulate_model -- --model-in path/to/bccwj-suw+unidic.model.zst --replace-dict path/to/dictionary.csv --model-out path/to/bccwj-suw+unidic-new.model.zst ``` -Now `メロンパン` is split into a single token. +Now `外国人参政権` is split into correct tokens. ``` -% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize-new.model.zst -朝食 は メロンパン 1 個 だっ た -0:朝食 -15398 -1:食は 24623 -2:はメ 30261 -3:メロ -126885 -4:ロン -138896 -5:ンパ -91838 -6:パン -123416 -7:ン1 23513 -8:1個 18435 -9:個だ 24964 -10:だっ -15065 -11:った 14178 +% echo '外国人参政権と政権交代' | cargo run --release -p predict -- --scores --model path/to/bccwj-suw+unidic-new.model.zst +外国 人 参政 権 と 政権 交代 +0:外国 -11785 +1:国人 16634 +2:人参 5450 +3:参政 -5520 +4:政権 6303 +5:権と 17702 +6:と政 18699 +7:政権 -12742 +8:権交 14578 +9:交代 -7658 ``` ### POS tagging diff --git a/manipulate_model/src/main.rs b/manipulate_model/src/main.rs index 6050409f..25445237 100644 --- a/manipulate_model/src/main.rs +++ b/manipulate_model/src/main.rs @@ -31,9 +31,7 @@ struct Args { #[derive(Deserialize, Serialize)] struct WordWeightRecordFlatten { word: String, - right: i32, - inside: i32, - left: i32, + weights: String, comment: String, } @@ -49,11 +47,10 @@ fn main() -> Result<(), Box> { let file = fs::File::create(path)?; let mut wtr = csv::Writer::from_writer(file); for data in model.dictionary() { + let str_weights: Vec<_> = data.get_weights().iter().map(|w| w.to_string()).collect(); wtr.serialize(WordWeightRecordFlatten { word: data.get_word().to_string(), - right: data.get_right_weight(), - inside: data.get_inside_weight(), - left: data.get_left_weight(), + weights: str_weights.join(" "), comment: data.get_comment().to_string(), })?; } @@ -66,13 +63,11 @@ fn main() -> Result<(), Box> { let mut dict = vec![]; for result in rdr.deserialize() { let record: WordWeightRecordFlatten = result?; - dict.push(WordWeightRecord::new( - record.word, - record.right, - record.inside, - record.left, - record.comment, - )); + let mut weights = vec![]; + for w in record.weights.split(' ') { + weights.push(w.parse()?); + } + dict.push(WordWeightRecord::new(record.word, weights, record.comment)?); } model.replace_dictionary(dict); } diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 688955a1..1d5eb981 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -270,16 +270,12 @@ impl CharScorer { } for 
d in dict.dict { let word_len = d.word.chars().count(); - let mut weight = Vec::with_capacity(word_len + 1); - weight.push(d.weights.right); - weight.resize(word_len, d.weights.inside); - weight.push(d.weights.left); let word_len = i16::try_from(word_len).map_err(|_| { VaporettoError::invalid_model( "words must be shorter than or equal to 32767 characters", ) })?; - let weight = PositionalWeight::new(-word_len - 1, weight); + let weight = PositionalWeight::new(-word_len - 1, d.weights); weight_merger.add(&d.word, weight); } @@ -376,16 +372,12 @@ impl CharScorerWithTags { } for d in dict.dict { let word_len = d.word.chars().count(); - let mut weight = Vec::with_capacity(word_len + 1); - weight.push(d.weights.right); - weight.resize(word_len, d.weights.inside); - weight.push(d.weights.left); let word_len = i16::try_from(word_len).map_err(|_| { VaporettoError::invalid_model( "words must be shorter than or equal to 32767 characters", ) })?; - let weight = WeightSet::boundary_weight(-word_len, weight); + let weight = WeightSet::boundary_weight(-word_len, d.weights); weight_merger.add(&d.word, weight); } for d in tag_left_model.data { diff --git a/vaporetto/src/dict_model.rs b/vaporetto/src/dict_model.rs index 61c7429f..65cbb030 100644 --- a/vaporetto/src/dict_model.rs +++ b/vaporetto/src/dict_model.rs @@ -3,7 +3,9 @@ use alloc::vec::Vec; use bincode::{Decode, Encode}; -#[derive(Clone, Copy, Default, Decode, Encode)] +use crate::errors::{Result, VaporettoError}; + +#[derive(Clone, Copy, Default)] pub struct DictWeight { pub right: i32, pub inside: i32, @@ -14,7 +16,7 @@ pub struct DictWeight { #[derive(Clone, Decode, Encode)] pub struct WordWeightRecord { pub(crate) word: String, - pub(crate) weights: DictWeight, + pub(crate) weights: Vec, pub(crate) comment: String, } @@ -24,24 +26,24 @@ impl WordWeightRecord { /// # Arguments /// /// * `word` - A word. - /// * `right` - A weight of the boundary when the word is found at right. - /// * `inside` - A weight of the boundary when the word is overlapped on the boundary. - /// * `left` - A weight of the boundary when the word is found at left. + /// * `weights` - A weight of boundaries. /// * `comment` - A comment that does not affect the behaviour. /// /// # Returns /// /// A new record. - pub const fn new(word: String, right: i32, inside: i32, left: i32, comment: String) -> Self { - Self { + pub fn new(word: String, weights: Vec, comment: String) -> Result { + if weights.len() != word.chars().count() + 1 { + return Err(VaporettoError::invalid_argument( + "weights", + "does not match the length of the `word`", + )); + } + Ok(Self { word, - weights: DictWeight { - right, - inside, - left, - }, + weights, comment, - } + }) } /// Gets a reference to the word. @@ -49,19 +51,9 @@ impl WordWeightRecord { &self.word } - /// Gets a `right` weight. - pub const fn get_right_weight(&self) -> i32 { - self.weights.right - } - - /// Gets a `inside` weight. - pub const fn get_inside_weight(&self) -> i32 { - self.weights.inside - } - - /// Gets a `left` weight. - pub const fn get_left_weight(&self) -> i32 { - self.weights.left + /// Gets weights. + pub fn get_weights(&self) -> &[i32] { + &self.weights } /// Gets a reference to the comment. 
diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index 718c116d..61a3af76 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -473,16 +473,19 @@ impl TryFrom for Model { let mut dict = vec![]; if let Some(kytea_dict) = model.dict { for (w, data) in kytea_dict.dump_items() { - let word_len = std::cmp::min(w.len(), config.dict_n as usize) - 1; - let mut weights = DictWeight::default(); + let idx = std::cmp::min(w.len(), config.dict_n as usize) - 1; + let mut dict_weight = DictWeight::default(); for j in 0..kytea_dict.n_dicts as usize { if data.in_dict >> j & 1 == 1 { - let offset = 3 * config.dict_n as usize * j + 3 * word_len; - weights.right += i32::from(feature_lookup.dict_vec[offset]); - weights.inside += i32::from(feature_lookup.dict_vec[offset + 1]); - weights.left += i32::from(feature_lookup.dict_vec[offset + 2]); + let offset = 3 * config.dict_n as usize * j + 3 * idx; + dict_weight.right += i32::from(feature_lookup.dict_vec[offset]); + dict_weight.inside += i32::from(feature_lookup.dict_vec[offset + 1]); + dict_weight.left += i32::from(feature_lookup.dict_vec[offset + 2]); } } + let mut weights = vec![dict_weight.inside; w.len() + 1]; + *weights.first_mut().unwrap() = dict_weight.right; + *weights.last_mut().unwrap() = dict_weight.left; dict.push(WordWeightRecord { word: w.into_iter().collect(), weights, diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index 3df28fa0..c97910de 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -338,7 +338,7 @@ mod tests { use alloc::string::ToString; - use crate::dict_model::{DictModel, DictWeight, WordWeightRecord}; + use crate::dict_model::{DictModel, WordWeightRecord}; use crate::ngram_model::{NgramData, NgramModel}; use crate::sentence::CharacterType::*; use crate::tag_model::TagModel; @@ -424,29 +424,17 @@ mod tests { dict: vec![ WordWeightRecord { word: "全世界".to_string(), - weights: DictWeight { - right: 43, - inside: 44, - left: 45, - }, + weights: vec![43, 44, 44, 45], comment: "".to_string(), }, WordWeightRecord { word: "世界".to_string(), - weights: DictWeight { - right: 43, - inside: 44, - left: 45, - }, + weights: vec![43, 44, 45], comment: "".to_string(), }, WordWeightRecord { word: "世".to_string(), - weights: DictWeight { - right: 40, - inside: 41, - left: 42, - }, + weights: vec![40, 42], comment: "".to_string(), }, ], @@ -534,29 +522,17 @@ mod tests { dict: vec![ WordWeightRecord { word: "全世界".to_string(), - weights: DictWeight { - right: 44, - inside: 45, - left: 46, - }, + weights: vec![44, 45, 45, 46], comment: "".to_string(), }, WordWeightRecord { word: "世界".to_string(), - weights: DictWeight { - right: 41, - inside: 42, - left: 43, - }, + weights: vec![41, 42, 43], comment: "".to_string(), }, WordWeightRecord { word: "世".to_string(), - weights: DictWeight { - right: 38, - inside: 39, - left: 40, - }, + weights: vec![38, 40], comment: "".to_string(), }, ], @@ -644,29 +620,17 @@ mod tests { dict: vec![ WordWeightRecord { word: "国民".to_string(), - weights: DictWeight { - right: 38, - inside: 39, - left: 40, - }, + weights: vec![38, 39, 40], comment: "".to_string(), }, WordWeightRecord { word: "世界".to_string(), - weights: DictWeight { - right: 41, - inside: 42, - left: 43, - }, + weights: vec![41, 42, 43], comment: "".to_string(), }, WordWeightRecord { word: "世".to_string(), - weights: DictWeight { - right: 44, - inside: 45, - left: 46, - }, + weights: vec![44, 46], comment: "".to_string(), }, ], @@ -762,47 +726,27 @@ mod tests { dict: 
vec![ WordWeightRecord { word: "全世界".to_string(), - weights: DictWeight { - right: 43, - inside: 44, - left: 45, - }, + weights: vec![43, 44, 44, 45], comment: "".to_string(), }, WordWeightRecord { word: "世界".to_string(), - weights: DictWeight { - right: 43, - inside: 44, - left: 45, - }, + weights: vec![43, 44, 45], comment: "".to_string(), }, WordWeightRecord { word: "世".to_string(), - weights: DictWeight { - right: 40, - inside: 41, - left: 42, - }, + weights: vec![40, 42], comment: "".to_string(), }, WordWeightRecord { word: "世界の国民".to_string(), - weights: DictWeight { - right: 43, - inside: 44, - left: 45, - }, + weights: vec![43, 44, 44, 44, 44, 45], comment: "".to_string(), }, WordWeightRecord { word: "は全世界".to_string(), - weights: DictWeight { - right: 43, - inside: 44, - left: 45, - }, + weights: vec![43, 44, 44, 44, 45], comment: "".to_string(), }, ], diff --git a/vaporetto/src/trainer.rs b/vaporetto/src/trainer.rs index f07fd2ef..3cff7d6d 100644 --- a/vaporetto/src/trainer.rs +++ b/vaporetto/src/trainer.rs @@ -379,10 +379,14 @@ impl<'a> Trainer<'a> { self.dictionary .into_iter() .map(|word| { - let idx = word.chars().count().min(dict_weights.len()) - 1; + let word_len = word.chars().count(); + let idx = word_len.min(dict_weights.len()) - 1; + let mut weights = vec![dict_weights[idx].inside; word_len + 1]; + *weights.first_mut().unwrap() = dict_weights[idx].right; + *weights.last_mut().unwrap() = dict_weights[idx].left; WordWeightRecord { word, - weights: dict_weights[idx], + weights, comment: "".to_string(), } })
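For reference, a minimal usage sketch of the changed constructor: `WordWeightRecord::new` now takes the word, a weight array, and a comment, and returns a `Result` that rejects arrays whose length is not `word.chars().count() + 1`. This assumes `WordWeightRecord` is reachable from the crate root, as the `manipulate_model` tool uses it; the record values are the illustrative ones from the README example, not shipped model weights.

```rust
use vaporetto::WordWeightRecord;

fn main() {
    // "参政権" has 3 characters, so 4 boundary weights are required:
    // [before 参, 参|政, 政|権, after 権].
    let record = WordWeightRecord::new(
        "参政権".to_string(),
        vec![0, -10000, 10000, 0],
        "参政/権".to_string(),
    )
    .expect("weights length must be chars + 1");
    assert_eq!(record.get_weights(), &[0, -10000, 10000, 0]);

    // A mismatched length is now rejected instead of silently mis-scoring.
    assert!(WordWeightRecord::new("参政権".to_string(), vec![0, -10000], String::new()).is_err());
}
```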