Skip to content

Commit

Permalink
Add vaporetto_tantivy (#53)
Browse files Browse the repository at this point in the history
* Add vaporetto_tantivy

* Add README

* fix

* Refactor

* fmt

* fix

* fix

* Add SplitLinebreaksFilter

* Add chars_and_boundaries_mut()
  • Loading branch information
vbkaisetsu authored Feb 14, 2022
1 parent 289224f commit cc620c0
Show file tree
Hide file tree
Showing 21 changed files with 669 additions and 147 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
members = [
"vaporetto",
"vaporetto_rules",
"vaporetto_tantivy",
"manipulate_model",
"predict",
"train",
Expand Down
6 changes: 2 additions & 4 deletions evaluate/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,11 @@ struct Opt {
fn main() -> Result<(), Box<dyn std::error::Error>> {
let opt = Opt::from_args();

let fullwidth_filter = KyteaFullwidthFilter::new();
let fullwidth_filter = KyteaFullwidthFilter;
let mut post_filters: Vec<Box<dyn SentenceFilter>> = vec![];
for wsconst in &opt.wsconst {
match wsconst {
WsConst::GraphemeCluster => {
post_filters.push(Box::new(ConcatGraphemeClustersFilter::new()))
}
WsConst::GraphemeCluster => post_filters.push(Box::new(ConcatGraphemeClustersFilter)),
WsConst::CharType(char_type) => {
post_filters.push(Box::new(KyteaWsConstFilter::new(*char_type)))
}
Expand Down
2 changes: 1 addition & 1 deletion manipulate_model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2018"

[dependencies]
csv = "1.1" # Unlicense OR MIT
csv = "1.1" # Unlicense or MIT
serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0
structopt = "0.3" # MIT or Apache-2.0
vaporetto = { path = "../vaporetto" } # MIT or Apache-2.0
Expand Down
6 changes: 2 additions & 4 deletions predict/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

let mut pre_filters: Vec<Box<dyn StringFilter>> = vec![];
if !opt.no_norm {
pre_filters.push(Box::new(KyteaFullwidthFilter::new()));
pre_filters.push(Box::new(KyteaFullwidthFilter));
}
let mut post_filters: Vec<Box<dyn SentenceFilter>> = vec![];
for wsconst in &opt.wsconst {
match wsconst {
WsConst::GraphemeCluster => {
post_filters.push(Box::new(ConcatGraphemeClustersFilter::new()))
}
WsConst::GraphemeCluster => post_filters.push(Box::new(ConcatGraphemeClustersFilter)),
WsConst::CharType(char_type) => {
post_filters.push(Box::new(KyteaWsConstFilter::new(*char_type)))
}
Expand Down
2 changes: 1 addition & 1 deletion train/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct Opt {
fn main() -> Result<(), Box<dyn std::error::Error>> {
let opt = Opt::from_args();

let fullwidth_filter = KyteaFullwidthFilter::new();
let fullwidth_filter = KyteaFullwidthFilter;

eprintln!("Loading dataset...");
let mut train_sents = vec![];
Expand Down
1 change: 0 additions & 1 deletion vaporetto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ repository = "https://github.com/legalforce-research/vaporetto"
readme = "README.md"
keywords = ["japanese", "analyzer", "tokenizer", "morphological"]
categories = ["text-processing"]
autotests = false

[dependencies]
daachorse = "0.4.0" # MIT or Apache-2.0
Expand Down
8 changes: 4 additions & 4 deletions vaporetto/src/char_scorer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::iter;
use std::rc::Rc;
use std::sync::Arc;

use daachorse::DoubleArrayAhoCorasick;

Expand Down Expand Up @@ -148,7 +148,7 @@ impl NaiveWeightSet {
boundary: None,
tag_left: None,
tag_right: None,
tag_self: Some(Rc::new(vec![TagRangeScore::new(
tag_self: Some(Arc::new(vec![TagRangeScore::new(
start_rel_position,
weight,
)])),
Expand All @@ -171,7 +171,7 @@ impl MergableWeight for NaiveWeightSet {
tag_self: utils::xor_or_zip_with(&weight1.tag_self, &weight2.tag_self, |w1, w2| {
let mut w = w1.to_vec();
w.append(&mut w2.to_vec());
Rc::new(w)
Arc::new(w)
}),
}
}
Expand Down Expand Up @@ -345,7 +345,7 @@ impl CharScorerWithTags {
.add_weight(&mut tag_ys.right_scores, offset);
}
if let Some(weight) = weight_set.tag_self.as_ref() {
tag_ys.self_scores[m_end - 1].replace(Rc::clone(weight));
tag_ys.self_scores[m_end - 1].replace(Arc::clone(weight));
}
}
}
Expand Down
37 changes: 20 additions & 17 deletions vaporetto/src/feature.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::hash::Hash;
use std::rc::Rc;
use std::sync::Arc;

use daachorse::DoubleArrayAhoCorasick;

Expand Down Expand Up @@ -213,7 +213,7 @@ impl<'a> TagFeature<'a> {
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct TagExample<'a> {
pub features: Vec<TagFeature<'a>>,
pub tag: Rc<String>,
pub tag: Arc<String>,
}

pub struct TagExampleGenerator {
Expand All @@ -240,8 +240,11 @@ impl TagExampleGenerator {
sentence.char_substring(start, sentence.chars.len()),
));
}
let mut current_tag: Option<Rc<String>> =
sentence.tags.last().and_then(|x| x.as_ref()).map(Rc::clone);
let mut current_tag: Option<Arc<String>> = sentence
.tags
.last()
.and_then(|x| x.as_ref())
.map(Arc::clone);
let mut tag_right_pos = sentence.chars.len();
for (i, (t, b)) in sentence
.tags
Expand Down Expand Up @@ -279,7 +282,7 @@ impl TagExampleGenerator {
features = vec![];
}
if let Some(tag) = t.as_ref() {
current_tag.replace(Rc::clone(tag));
current_tag.replace(Arc::clone(tag));
tag_right_pos = i + 1;
for j in
(i + 2)..(i + 2 + self.char_window_size).min(sentence.chars.len() + 1)
Expand Down Expand Up @@ -479,7 +482,7 @@ mod tests {
TagFeature::left_char_ngram_bos(-1, "Ar"),
TagFeature::chars("Aria"),
],
tag: Rc::new("名詞".to_string()),
tag: Arc::new("名詞".to_string()),
},
TagExample {
features: vec![
Expand All @@ -503,7 +506,7 @@ mod tests {
TagFeature::left_char_ngram(-1, "aは火"),
TagFeature::chars("は"),
],
tag: Rc::new("助詞".to_string()),
tag: Arc::new("助詞".to_string()),
},
TagExample {
features: vec![
Expand All @@ -520,7 +523,7 @@ mod tests {
TagFeature::left_char_ngram(-1, "猫だ"),
TagFeature::chars("だ"),
],
tag: Rc::new("助動詞".to_string()),
tag: Arc::new("助動詞".to_string()),
},
];

Expand Down Expand Up @@ -560,7 +563,7 @@ mod tests {
TagFeature::left_char_ngram_bos(-1, "Ar"),
TagFeature::chars("Aria"),
],
tag: Rc::new("名詞".to_string()),
tag: Arc::new("名詞".to_string()),
},
TagExample {
features: vec![
Expand All @@ -578,7 +581,7 @@ mod tests {
TagFeature::left_char_ngram(-1, "aは火"),
TagFeature::chars("は"),
],
tag: Rc::new("助詞".to_string()),
tag: Arc::new("助詞".to_string()),
},
TagExample {
features: vec![
Expand All @@ -592,7 +595,7 @@ mod tests {
TagFeature::left_char_ngram(-1, "猫だ"),
TagFeature::chars("だ"),
],
tag: Rc::new("助動詞".to_string()),
tag: Arc::new("助動詞".to_string()),
},
];

Expand Down Expand Up @@ -631,7 +634,7 @@ mod tests {
TagFeature::left_char_ngram_bos(-1, "A"),
TagFeature::chars("Aria"),
],
tag: Rc::new("名詞".to_string()),
tag: Arc::new("名詞".to_string()),
},
TagExample {
features: vec![
Expand All @@ -649,7 +652,7 @@ mod tests {
TagFeature::left_char_ngram(-1, "aは"),
TagFeature::chars("は"),
],
tag: Rc::new("助詞".to_string()),
tag: Arc::new("助詞".to_string()),
},
TagExample {
features: vec![
Expand All @@ -663,7 +666,7 @@ mod tests {
TagFeature::left_char_ngram(-1, "猫だ"),
TagFeature::chars("だ"),
],
tag: Rc::new("助動詞".to_string()),
tag: Arc::new("助動詞".to_string()),
},
];

Expand Down Expand Up @@ -704,7 +707,7 @@ mod tests {
TagFeature::left_char_ngram_bos(-1, "僕は"),
TagFeature::chars("僕"),
],
tag: Rc::new("代名詞".to_string()),
tag: Arc::new("代名詞".to_string()),
},
TagExample {
features: vec![
Expand All @@ -725,7 +728,7 @@ mod tests {
TagFeature::left_char_ngram(-1, "僕"),
TagFeature::chars("は"),
],
tag: Rc::new("助詞".to_string()),
tag: Arc::new("助詞".to_string()),
},
TagExample {
features: vec![
Expand All @@ -743,7 +746,7 @@ mod tests {
TagFeature::left_char_ngram(-1, "は"),
TagFeature::chars("人間"),
],
tag: Rc::new("名詞".to_string()),
tag: Arc::new("名詞".to_string()),
},
];

Expand Down
10 changes: 5 additions & 5 deletions vaporetto/src/predictor.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::mem;

use std::cmp::Ordering;
use std::rc::Rc;
use std::sync::Arc;

use crate::char_scorer::{self, CharScorer, CharScorerWithTags};
use crate::errors::Result;
Expand All @@ -24,7 +24,7 @@ pub struct Predictor {
padding: usize,

// for tag prediction
tag_names: Vec<Rc<String>>,
tag_names: Vec<Arc<String>>,
tag_bias: Vec<i32>,
}

Expand All @@ -45,7 +45,7 @@ impl Predictor {

let char_scorer = if predict_tags {
for cls in model.tag_model.class_info {
tag_names.push(Rc::new(cls.name));
tag_names.push(Arc::new(cls.name));
tag_bias.push(cls.bias);
}
CharScorerWrapper::BoundaryAndTags(CharScorerWithTags::new(
Expand Down Expand Up @@ -142,8 +142,8 @@ impl Predictor {
sentence
}

fn best_tag(&self, scores: &[i32]) -> Rc<String> {
Rc::clone(
fn best_tag(&self, scores: &[i32]) -> Arc<String> {
Arc::clone(
scores
.iter()
.zip(&self.tag_names)
Expand Down
Loading

0 comments on commit cc620c0

Please sign in to comment.