From dbbe79461614293c62956a1ccbd24603b05bd119 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Tue, 5 Sep 2023 16:58:47 +0900 Subject: [PATCH] Add --tag-scores option to show tag scores (#107) * Add --tag-scores option * Update README * fix --- README-ja.md | 17 ++++++-- README.md | 13 +++++- predict/src/main.rs | 35 +++++++++++++++- vaporetto/src/predictor.rs | 86 ++++++++++++++++++++++++++------------ vaporetto/src/sentence.rs | 41 ++++++++++++++++++ 5 files changed, 159 insertions(+), 33 deletions(-) diff --git a/README-ja.md b/README-ja.md index 57f00bc..8e6c0aa 100644 --- a/README-ja.md +++ b/README-ja.md @@ -205,9 +205,9 @@ Vaporetto は2種類のコーパス(フルアノテーションコーパスと 9:交代 -5794 ``` -### 品詞推定 +### タグ予測 -Vaporettoは実験的にタグ推定(品詞推定や読み推定)に対応しています。 +Vaporettoは実験的にタグ予測(品詞予測や読み予測)に対応しています。 タグを学習するには、以下のように、データセットの各トークンに続けてスラッシュとタグを追加します。 @@ -226,7 +226,18 @@ Vaporettoは実験的にタグ推定(品詞推定や読み推定)に対応 データセットにタグが含まれる場合、 `train` コマンドは自動的にそれらを学習します。 -推定時は、デフォルトではタグは推定されないため、必要に応じで `predict` コマンドに `--predict-tags` 引数を指定してください。 +予測時は、デフォルトではタグは予測されないため、必要に応じて `predict` コマンドに `--predict-tags` 引数を指定してください。 + +`--tag-scores` 引数を指定すると、タグ予測の際に計算された各候補のスコアを表示できます。 +タグの候補が1つしかない場合は、スコアが0と表示されます。 + +``` +% echo "花が咲く" | cargo run --release -p predict -- --model path/to/bccwj-suw+unidic_pos+pron.model.zst --predict-tags --tag-scores +花/名詞-普通名詞-一般/ハナ が/助詞-格助詞/ガ 咲く/動詞-一般/サク +花 名詞-普通名詞-一般:18613,接尾辞-名詞的-一般:-18613 ハナ:19973,バナ:-20377,カ:-20480,ゲ:-20410 +が 助詞-接続助詞:-20408,助詞-格助詞:23543,接続詞:-25332 ガ:0 +咲く 動詞-一般:0 サク:0 +``` ## 各種トークナイザの速度比較 diff --git a/README.md b/README.md index b9c6964..05af1d7 100644 --- a/README.md +++ b/README.md @@ -211,7 +211,7 @@ Now `外国人参政権` is split into correct tokens. 9:交代 -5794 ``` -### Tagging +### Tag prediction Vaporetto experimentally supports tagging (e.g., part-of-speech and pronunciation tags). @@ -234,6 +234,17 @@ If the dataset contains tags, the `train` command automatically trains them. In prediction, tags are not predicted by default, so you have to specify the `--predict-tags` argument to the `predict` command if necessary. +If you specify the `--tag-scores` argument, the score of each candidate calculated during tag prediction is displayed. +If there is only one candidate, the score becomes 0. + +``` +% echo "花が咲く" | cargo run --release -p predict -- --model path/to/bccwj-suw+unidic_pos+pron.model.zst --predict-tags --tag-scores +花/名詞-普通名詞-一般/ハナ が/助詞-格助詞/ガ 咲く/動詞-一般/サク +花 名詞-普通名詞-一般:18613,接尾辞-名詞的-一般:-18613 ハナ:19973,バナ:-20377,カ:-20480,ゲ:-20410 +が 助詞-接続助詞:-20408,助詞-格助詞:23543,接続詞:-25332 ガ:0 +咲く 動詞-一般:0 サク:0 +``` + ## Speed Comparison of Various Tokenizers Vaporetto is 8.7 times faster than KyTea. diff --git a/predict/src/main.rs b/predict/src/main.rs index 8d3b743..485691a 100644 --- a/predict/src/main.rs +++ b/predict/src/main.rs @@ -50,10 +50,14 @@ struct Args { #[arg(long)] wsconst: Vec, - /// Prints scores. + /// Prints boundary scores. #[arg(long)] scores: bool, + /// Prints tag scores. + #[arg(long)] + tag_scores: bool, + /// Do not normalize input strings before prediction. #[arg(long)] no_norm: bool, @@ -70,6 +74,24 @@ fn print_scores(s: &Sentence, mut out: impl Write) -> Result<(), Box Result<(), Box> { + for token in s.iter_tokens() { + out.write_all(token.surface().as_bytes())?; + for cands in token.tag_candidates() { + out.write_all(b"\t")?; + for (i, (tag, score)) in cands.iter().enumerate() { + if i != 0 { + out.write_all(b",")?; + } + write!(out, "{tag}:{score}")?; + } + } + out.write_all(b"\n")?; + } + out.write_all(b"\n")?; + Ok(()) +} + fn main() -> Result<(), Box> { let args = Args::parse(); @@ -87,7 +109,10 @@ fn main() -> Result<(), Box> { eprintln!("Loading model file..."); let mut f = zstd::Decoder::new(File::open(args.model)?)?; let model = Model::read(&mut f)?; - let predictor = Predictor::new(model, args.predict_tags)?; + let mut predictor = Predictor::new(model, args.predict_tags)?; + if args.tag_scores { + predictor.store_tag_scores(true); + } let is_tty = atty::is(atty::Stream::Stdout); @@ -114,6 +139,9 @@ fn main() -> Result<(), Box> { } } out.write_all(b"\n")?; + if args.tag_scores { + print_tag_scores(&s, &mut out)?; + } if is_tty { out.flush()?; } @@ -143,6 +171,9 @@ fn main() -> Result<(), Box> { } else { out.write_all(b"\n")?; } + if args.tag_scores { + print_tag_scores(&s, &mut out)?; + } if is_tty { out.flush()?; } diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index bb2fa66..dad4d1a 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -430,7 +430,10 @@ assert_eq!( ``` " )] -pub struct Predictor(PredictorData); +pub struct Predictor { + data: PredictorData, + tag_scores: bool, +} impl Predictor { /// Creates a new predictor from the model. @@ -487,16 +490,25 @@ impl Predictor { #[cfg(feature = "tag-prediction")] tag_type_ngram_model, )?; - Ok(Self(PredictorData { - char_scorer, - type_scorer, - bias: model.0.bias, + Ok(Self { + data: PredictorData { + char_scorer, + type_scorer, + bias: model.0.bias, + + #[cfg(feature = "tag-prediction")] + tag_predictor, + #[cfg(feature = "tag-prediction")] + n_tags, + }, + tag_scores: false, + }) + } - #[cfg(feature = "tag-prediction")] - tag_predictor, - #[cfg(feature = "tag-prediction")] - n_tags, - })) + /// Stores tag scores if the given `flag` is `true`. + #[cfg(feature = "tag-prediction")] + pub fn store_tag_scores(&mut self, flag: bool) { + self.tag_scores = flag; } /// Predicts word boundaries of the given sentence. @@ -504,13 +516,14 @@ impl Predictor { pub fn predict<'a>(&'a self, sentence: &mut Sentence<'_, 'a>) { sentence.score_padding = WEIGHT_FIXED_LEN - 1; sentence.boundary_scores.clear(); - sentence - .boundary_scores - .resize(sentence.score_padding * 2 + sentence.len() - 1, self.0.bias); - if let Some(scorer) = self.0.char_scorer.as_ref() { + sentence.boundary_scores.resize( + sentence.score_padding * 2 + sentence.len() - 1, + self.data.bias, + ); + if let Some(scorer) = self.data.char_scorer.as_ref() { scorer.add_scores(sentence); } - if let Some(scorer) = self.0.type_scorer.as_ref() { + if let Some(scorer) = self.data.type_scorer.as_ref() { scorer.add_scores(sentence); } for (b, s) in sentence @@ -530,19 +543,25 @@ impl Predictor { #[cfg(feature = "tag-prediction")] pub(crate) fn predict_tags<'a>(&'a self, sentence: &mut Sentence<'_, 'a>) { let tag_predictor = self - .0 + .data .tag_predictor .as_ref() .expect("this predictor is created with predict_tags = false"); - if self.0.n_tags == 0 { + if self.data.n_tags == 0 { return; } let mut scores = vec![]; let mut range_start = Some(0); - sentence.n_tags = self.0.n_tags; + sentence.n_tags = self.data.n_tags; sentence.tags.clear(); - sentence.tags.resize(sentence.len() * self.0.n_tags, None); + sentence + .tags + .resize(sentence.len() * self.data.n_tags, None); + sentence.tag_scores.clear(); + if self.tag_scores { + sentence.tag_scores.resize(sentence.len(), None); + } for (i, &b) in sentence.boundaries.iter().enumerate() { if b == CharacterBoundary::Unknown { range_start.take(); @@ -553,7 +572,7 @@ impl Predictor { scores.clear(); scores.resize(tag_predictor.bias().len(), 0); tag_predictor.bias().add_scores(&mut scores); - if let Some(scorer) = self.0.char_scorer.as_ref() { + if let Some(scorer) = self.data.char_scorer.as_ref() { debug_assert!(i < sentence.char_pma_states.len()); // token_id is always smaller than tag_weight.len() because // tag_predictor is created to contain such values in the new() @@ -562,7 +581,7 @@ impl Predictor { scorer.add_tag_scores(*token_id, i, sentence, &mut scores); } } - if let Some(scorer) = self.0.type_scorer.as_ref() { + if let Some(scorer) = self.data.type_scorer.as_ref() { debug_assert!(i < sentence.type_pma_states.len()); // token_id is always smaller than tag_weight.len() because // tag_predictor is created to contain such values in the new() @@ -573,8 +592,12 @@ impl Predictor { } tag_predictor.predict( &scores, - &mut sentence.tags[i * self.0.n_tags..(i + 1) * self.0.n_tags], + &mut sentence.tags[i * self.data.n_tags..(i + 1) * self.data.n_tags], ); + if !sentence.tag_scores.is_empty() { + sentence.tag_scores[i].replace((&tag_predictor.tags, scores)); + scores = vec![]; + } } } range_start.replace(i + 1); @@ -586,7 +609,7 @@ impl Predictor { scores.clear(); scores.resize(tag_predictor.bias().len(), 0); tag_predictor.bias().add_scores(&mut scores); - if let Some(scorer) = self.0.char_scorer.as_ref() { + if let Some(scorer) = self.data.char_scorer.as_ref() { debug_assert!(sentence.len() <= sentence.char_pma_states.len()); // token_id is always smaller than tag_weight.len() because tag_predictor is // created to contain such values in the new() function. @@ -594,7 +617,7 @@ impl Predictor { scorer.add_tag_scores(*token_id, sentence.len() - 1, sentence, &mut scores); } } - if let Some(scorer) = self.0.type_scorer.as_ref() { + if let Some(scorer) = self.data.type_scorer.as_ref() { debug_assert!(sentence.len() <= sentence.type_pma_states.len()); // token_id is always smaller than tag_weight.len() because tag_predictor is // created to contain such values in the new() function. @@ -603,7 +626,10 @@ impl Predictor { } } let i = sentence.len() - 1; - tag_predictor.predict(&scores, &mut sentence.tags[i * self.0.n_tags..]); + tag_predictor.predict(&scores, &mut sentence.tags[i * self.data.n_tags..]); + if !sentence.tag_scores.is_empty() { + sentence.tag_scores[i].replace((&tag_predictor.tags, scores)); + } } } } @@ -611,7 +637,7 @@ impl Predictor { /// Serializes the predictor into a Vec. pub fn serialize_to_vec(&self) -> Result> { let config = bincode::config::standard(); - let result = bincode::encode_to_vec(&self.0, config)?; + let result = bincode::encode_to_vec(&self.data, config)?; Ok(result) } @@ -625,7 +651,13 @@ impl Predictor { let config = bincode::config::standard(); // Deserialization is unsafe because the automaton will not be verified. let (predictor_data, size) = bincode::borrow_decode_from_slice(data, config)?; - Ok((Self(predictor_data), &data[size..])) + Ok(( + Self { + data: predictor_data, + tag_scores: false, + }, + &data[size..], + )) } } diff --git a/vaporetto/src/sentence.rs b/vaporetto/src/sentence.rs index 066dc9c..3d4dd06 100644 --- a/vaporetto/src/sentence.rs +++ b/vaporetto/src/sentence.rs @@ -91,6 +91,8 @@ pub struct Sentence<'a, 'b> { pub(crate) char_pma_states: Vec, pub(crate) type_pma_states: Vec, pub(crate) tags: Vec>>, + #[allow(clippy::type_complexity)] + pub(crate) tag_scores: Vec], Vec)>>, pub(crate) n_tags: usize, predictor: Option<&'b Predictor>, str_to_char_pos: Vec, @@ -120,6 +122,7 @@ impl<'a, 'b> Default for Sentence<'a, 'b> { char_pma_states: vec![], type_pma_states: vec![], tags: vec![], + tag_scores: vec![], n_tags: 0, predictor: None, str_to_char_pos: vec![], @@ -232,6 +235,7 @@ impl<'a, 'b> Sentence<'a, 'b> { type_pma_states: vec![], predictor: None, tags: vec![], + tag_scores: vec![], n_tags: 0, str_to_char_pos, char_to_str_pos, @@ -451,6 +455,7 @@ impl<'a, 'b> Sentence<'a, 'b> { type_pma_states: vec![], predictor: None, tags, + tag_scores: vec![], n_tags, str_to_char_pos, char_to_str_pos, @@ -688,6 +693,7 @@ impl<'a, 'b> Sentence<'a, 'b> { type_pma_states: vec![], predictor: None, tags, + tag_scores: vec![], n_tags, str_to_char_pos, char_to_str_pos, @@ -1203,6 +1209,41 @@ impl<'a, 'b> Token<'a, 'b> { &self.sentence.tags[start..end] } + /// Returns tag candidates with scores. + /// + /// The return value is a two-dimensional array. The outer array index corresponding to the + /// return value of [`Token::tags()`]. The inner array is a candidate set, where each element + /// is a tuple of the tag name and its score. + /// + /// # Panics + /// + /// This function panics if [`Predictor::store_tag_scores()`] is set to false. + #[cfg(feature = "tag-prediction")] + #[cfg_attr(docsrs, doc(cfg(feature = "tag-prediction")))] + pub fn tag_candidates(&self) -> Vec> { + let mut results = vec![]; + assert!( + !self.sentence.tag_scores.is_empty(), + "Predictor::store_tag_scores() must be set to true to use this function.", + ); + if let Some((tags, scores)) = self.sentence.tag_scores[self.end - 1].as_ref() { + let mut i = 0; + for cands in *tags { + let mut inner = vec![]; + if cands.len() == 1 { + inner.push((cands[0].as_str(), 0)); + } else { + for cand in cands { + inner.push((cand.as_str(), scores[i])); + i += 1; + } + } + results.push(inner); + } + } + results + } + /// Returns the start position of this token in characters. #[inline] pub const fn start(&self) -> usize {