From 744893ff6990b46cfd12f2d88fba915eef8c4951 Mon Sep 17 00:00:00 2001
From: Binh Vu
Date: Mon, 28 Aug 2023 22:29:16 -0700
Subject: [PATCH] update mapreduce & scripts to get representative value of a class

---
 .vscode/tasks.json                       |  22 +++
 Cargo.lock                               |  17 +-
 Cargo.toml                               |  11 +-
 benches/{rc_vs_py.rs => rc_vs_py.rs.tmp} |   0
 data/.gitignore                          |   2 +
 src/error.rs                             |   2 +
 src/lib.rs                               |   1 +
 src/main.rs                              |  23 +++
 src/mapreduce/dataset.rs                 |  73 +++++++-
 src/mapreduce/foldop.rs                  |  34 ++++
 src/mapreduce/functions.rs               |  64 ++++++-
 src/mapreduce/miscop.rs                  |  23 +++
 src/mapreduce/mod.rs                     | 136 ++++++++++++---
 src/mapreduce/sortop.rs                  |  32 ++--
 src/models/property.rs                   |   2 +-
 src/pyo3helper/hashbrown.rs              |   2 +-
 src/python/scripts.rs                    | 212 ++++++++++++++++++-----
 17 files changed, 555 insertions(+), 101 deletions(-)
 create mode 100644 .vscode/tasks.json
 rename benches/{rc_vs_py.rs => rc_vs_py.rs.tmp} (100%)
 create mode 100644 data/.gitignore
 create mode 100644 src/main.rs
 create mode 100644 src/mapreduce/foldop.rs
 create mode 100644 src/mapreduce/miscop.rs

diff --git a/.vscode/tasks.json b/.vscode/tasks.json
new file mode 100644
index 0000000..e911fe4
--- /dev/null
+++ b/.vscode/tasks.json
@@ -0,0 +1,22 @@
+{
+  "version": "2.0.0",
+  "tasks": [
+    {
+      "type": "cargo",
+      "command": "run",
+      "args": [
+        "--package",
+        "kgdata",
+        "--bin",
+        "kgdata",
+        "--features",
+        "pyo3/auto-initialize"
+      ],
+      "problemMatcher": [
+        "$rustc"
+      ],
+      "group": "build",
+      "label": "rust: run kgdata"
+    }
+  ]
+}
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index 53eaa74..c15414b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -709,6 +709,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
 dependencies = [
  "ahash",
+ "rayon",
  "serde",
 ]
 
@@ -817,6 +818,7 @@ dependencies = [
  "glob",
  "hashbrown 0.13.2",
  "log",
+ "ord_subset",
  "petgraph",
  "pyo3",
  "rayon",
@@ -1021,6 +1023,12 @@ version = "11.1.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
 
+[[package]]
+name = "ord_subset"
+version = "3.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d7ce14664caf5b27f5656ff727defd68ae1eb75ef3c4d95259361df1eb376bef"
+
 [[package]]
 name = "os_str_bytes"
 version = "6.4.1"
@@ -1149,6 +1157,7 @@ checksum = "ffb88ae05f306b4bfcde40ac4a51dc0b05936a9207a4b75b798c7729c4258a59"
 dependencies = [
  "anyhow",
  "cfg-if",
+ "hashbrown 0.13.2",
  "indoc",
  "inventory",
  "libc",
@@ -1214,9 +1223,9 @@ dependencies = [
 
 [[package]]
 name = "rayon"
-version = "1.6.1"
+version = "1.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6db3a213adf02b3bcfd2d3846bb41cb22857d131789e01df434fb7e7bc0759b7"
+checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b"
 dependencies = [
  "either",
  "rayon-core",
 ]
 
 [[package]]
 name = "rayon-core"
-version = "1.10.2"
+version = "1.11.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "356a0625f1954f730c0201cdab48611198dc6ce21f4acff55089b5a78e6e835b"
+checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d"
 dependencies = [
  "crossbeam-channel",
  "crossbeam-deque",
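Note on the lockfile changes above: they pair with the Cargo.toml diff below, where hashbrown gains its `rayon` feature (parallel iteration over its maps) and pyo3 gains its `hashbrown` feature. A minimal sketch of what the hashbrown/rayon integration enables; this standalone example is not part of the patch:

```rust
// Minimal sketch (not part of this patch): with hashbrown's "rayon" feature,
// its HashMap implements rayon's IntoParallelIterator, which is exactly what
// the new MapDataset in src/mapreduce/dataset.rs below relies on.
use hashbrown::HashMap;
use rayon::prelude::*;

fn main() {
    let map: HashMap<u64, u64> = (0..1_000u64).map(|i| (i, i * i)).collect();
    // Consumes the map and iterates its (key, value) pairs in parallel.
    let sum: u64 = map.into_par_iter().map(|(_, v)| v).sum();
    println!("{sum}");
}
```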
diff --git a/Cargo.toml b/Cargo.toml
index d9f431c..8d507c5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,11 +24,16 @@ flate2 = { version = "1.0.24", features = [
     "zlib-ng",
 ], default-features = false }
 glob = "0.3.1"
-hashbrown = { version = "0.13.2", features = ["serde"] }
+hashbrown = { version = "0.13.2", features = ["serde", "rayon"] }
 log = "0.4.17"
+ord_subset = "3.1.1"
 petgraph = "0.6.3"
-pyo3 = { version = "0.19.1", features = ["anyhow", "multiple-pymethods"] }
-rayon = "1.5.3"
+pyo3 = { version = "0.19.1", features = [
+    "anyhow",
+    "multiple-pymethods",
+    "hashbrown",
+] }
+rayon = "1.7.0"
 rocksdb = "0.20.1"
 serde = { version = "1.0.137", features = ["derive"] }
 serde_json = "1.0.81"
diff --git a/benches/rc_vs_py.rs b/benches/rc_vs_py.rs.tmp
similarity index 100%
rename from benches/rc_vs_py.rs
rename to benches/rc_vs_py.rs.tmp
diff --git a/data/.gitignore b/data/.gitignore
new file mode 100644
index 0000000..c96a04f
--- /dev/null
+++ b/data/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
\ No newline at end of file
diff --git a/src/error.rs b/src/error.rs
index dc9918d..d8edfc3 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -39,3 +39,5 @@ pub fn into_pyerr<E: Into<KGDataError>>(err: E) -> PyErr {
         anyerror.into()
     }
 }
+
+pub type KGResult<T> = Result<T, KGDataError>;
diff --git a/src/lib.rs b/src/lib.rs
index e450fb2..cd40909 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -23,6 +23,7 @@ fn core(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(init_env_logger, m)?)?;
 
     python::models::register(py, m)?;
+    m.add_class::<python::scripts::GetRepresentativeValue>()?;
 
     Ok(())
 }
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..86eba0e
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,23 @@
+use anyhow::Result;
+use hashbrown::HashSet;
+use kgdata::mapreduce::*;
+use kgdata::python::scripts::GetRepresentativeValue;
+use kgdata::{error::into_pyerr, mapreduce::from_jl_files, python::scripts::EntityTypesAndDegrees};
+use pyo3::prelude::*;
+
+fn main() -> PyResult<()> {
+    let args = GetRepresentativeValue {
+        data_dir: "/Volumes/research/kgdata/data/dbpedia/20221201".to_string(),
+        class_ids: HashSet::from_iter(vec!["http://dbpedia.org/ontology/Person".to_string()]),
+        kgname: "dbpedia".to_string(),
+        topk: 1000,
+    };
+
+    // Python::with_gil(|py| {
+    //     let res = GetRepresentativeValue::calculate_stats(py, &args).unwrap();
+    //     println!("{:?}", res);
+    // });
+
+    println!("Hello, world!");
+    Ok(())
+}
diff --git a/src/mapreduce/dataset.rs b/src/mapreduce/dataset.rs
index 7b5f2e4..348c876 100644
--- a/src/mapreduce/dataset.rs
+++ b/src/mapreduce/dataset.rs
@@ -1,17 +1,48 @@
-use std::path::PathBuf;
-
+use core::hash::Hash;
+use hashbrown::HashMap;
 use rayon::prelude::{IntoParallelIterator, ParallelIterator};
+use std::path::PathBuf;
 
 use crate::error::KGDataError;
 
 use super::{FromParallelDataset, ParallelDataset};
 
 pub struct Dataset<I> {
-    items: Vec<I>,
+    pub items: Vec<I>,
+}
+
+pub struct MapDataset<K, V>
+where
+    K: Hash + Eq + Send,
+    V: Send,
+{
+    pub map: HashMap<K, V>,
+}
+
+pub struct RefDataset<'t, I> {
+    pub items: &'t Vec<I>,
 }
 
 impl<I> ParallelDataset for Dataset<I> where I: Send {}
 
+impl<'t, I> ParallelDataset for RefDataset<'t, I> where I: Sync + 't {}
+
+impl<'t, I> RefDataset<'t, I>
+where
+    I: 't,
+{
+    pub fn new(items: &'t Vec<I>) -> Self {
+        Self { items }
+    }
+}
+
+impl<K, V> ParallelDataset for MapDataset<K, V>
+where
+    K: Hash + Eq + Send,
+    V: Send,
+{
+}
+
 impl<I> IntoParallelIterator for Dataset<I>
 where
     I: Send,
@@ -24,6 +55,31 @@ where
     }
 }
 
+impl<'t, I> IntoParallelIterator for RefDataset<'t, I>
+where
+    I: Sync + 't,
+{
+    type Iter = rayon::slice::Iter<'t, I>;
+    type Item = &'t I;
+
+    fn into_par_iter(self) -> Self::Iter {
+        self.items.into_par_iter()
+    }
+}
+
+impl<K, V> IntoParallelIterator for MapDataset<K, V>
+where
+    K: Hash + Eq + Send,
+    V: Send,
+{
+    type Iter = hashbrown::hash_map::rayon::IntoParIter<K, V>;
+    type Item = (K, V);
+
+    fn into_par_iter(self) -> Self::Iter {
+        self.map.into_par_iter()
+    }
+}
+
 impl Dataset<PathBuf> {
     pub fn files(glob: &str) -> Result<Self, KGDataError> {
         let items = glob::glob(glob)?
@@ -61,3 +117,14 @@ where
         Ok(Dataset { items })
     }
 }
+
+impl<I> FromIterator<I> for Dataset<I> {
+    fn from_iter<T>(iter: T) -> Self
+    where
+        T: IntoIterator<Item = I>,
+    {
+        Self {
+            items: iter.into_iter().collect::<Vec<_>>(),
+        }
+    }
+}
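The dataset module now offers three flavors: `Dataset<I>` owns a `Vec<I>`, `RefDataset<'t, I>` borrows one (so large inputs can be re-scanned without cloning), and `MapDataset<K, V>` wraps a hashbrown map produced by grouping. A usage sketch under the assumption that `map`/`filter` come from the crate's existing mapop/filterop modules; the data is made up for illustration:

```rust
// Usage sketch for the new dataset flavors; not part of the patch.
use kgdata::mapreduce::{Dataset, ParallelDataset, RefDataset};

fn demo() {
    // Owned dataset over a Vec<I>.
    let owned = Dataset { items: vec![1, 2, 3, 4] };
    let doubled: Vec<i32> = owned.map(|x| x * 2).collect();
    assert_eq!(doubled, vec![2, 4, 6, 8]);

    // Borrowed dataset: parallel iteration over &Vec<I> without cloning,
    // the same pattern get_examples uses later with RefDataset::new(&matched_ents).
    let items = vec![1, 2, 3, 4];
    let evens: Vec<&i32> = RefDataset::new(&items).filter(|x| **x % 2 == 0).collect();
    assert_eq!(evens, vec![&2, &4]);
}
```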
diff --git a/src/mapreduce/foldop.rs b/src/mapreduce/foldop.rs
new file mode 100644
index 0000000..bf51900
--- /dev/null
+++ b/src/mapreduce/foldop.rs
@@ -0,0 +1,34 @@
+use rayon::prelude::*;
+
+use super::ParallelDataset;
+
+#[derive(Clone)]
+pub struct FoldOp<D, ID, F> {
+    pub base: D,
+    pub identity: ID,
+    pub op: F,
+}
+
+impl<T, D, ID, F> IntoParallelIterator for FoldOp<D, ID, F>
+where
+    D: ParallelDataset,
+    F: (Fn(T, D::Item) -> T) + Sync + Send,
+    ID: Fn() -> T + Sync + Send,
+    T: Send,
+{
+    type Iter = rayon::iter::Fold<D::Iter, ID, F>;
+    type Item = T;
+
+    fn into_par_iter(self) -> Self::Iter {
+        self.base.into_par_iter().fold(self.identity, self.op)
+    }
+}
+
+impl<T, D, ID, F> ParallelDataset for FoldOp<D, ID, F>
+where
+    D: ParallelDataset,
+    F: (Fn(T, D::Item) -> T) + Sync + Send,
+    ID: Fn() -> T + Sync + Send,
+    T: Send,
+{
+}
diff --git a/src/mapreduce/functions.rs b/src/mapreduce/functions.rs
index bca6952..72c8321 100644
--- a/src/mapreduce/functions.rs
+++ b/src/mapreduce/functions.rs
@@ -1,27 +1,47 @@
-use std::{fs::File, io::BufRead, io::BufReader, path::PathBuf};
+use std::{ffi::OsStr, fs::File, io::BufRead, io::BufReader, path::PathBuf};
 
-use rayon::prelude::*;
+use flate2::read::GzDecoder;
 use serde::Deserialize;
 
 use crate::error::KGDataError;
 
 use super::*;
 
-pub fn make_try_flat_map_fn<F, T, OPI, E>(func: F) -> impl Fn(T) -> Vec<Result<OPI::Item, E>>
+pub fn make_begin_try_flat_map_fn<F, I, OPI, E>(func: F) -> impl Fn(I) -> Vec<Result<OPI::Item, E>>
 where
-    F: Fn(T) -> Result<OPI, E>,
-    OPI: IntoParallelIterator,
+    F: Fn(I) -> Result<OPI, E>,
+    OPI: IntoIterator,
     E: Send,
 {
     move |value| {
         let out = func(value);
         match out {
-            Ok(v) => v.into_par_iter().map(Ok).collect::<Vec<_>>(),
+            Ok(v) => v.into_iter().map(Ok).collect::<Vec<_>>(),
             Err(e) => vec![Err(e)],
         }
     }
 }
 
+pub fn make_try_flat_map_fn<F, I, OPI, E>(
+    func: F,
+) -> impl Fn(Result<I, E>) -> Vec<Result<OPI::Item, E>>
+where
+    F: Fn(I) -> Result<OPI, E>,
+    OPI: IntoIterator,
+    E: Send,
+{
+    move |value| match value {
+        Ok(value) => {
+            let out = func(value);
+            match out {
+                Ok(v) => v.into_iter().map(Ok).collect::<Vec<_>>(),
+                Err(e) => vec![Err(e)],
+            }
+        }
+        Err(e) => vec![Err(e)],
+    }
+}
+
 pub fn make_try_fn<F, I, O, E>(func: F) -> impl Fn(Result<I, E>) -> Result<O, E>
 where
     F: Fn(I) -> Result<O, E>,
@@ -50,19 +70,45 @@ pub fn from_jl_files<T>(
     glob: &str,
 ) -> Result<impl ParallelDataset<Item = Result<T, KGDataError>>, KGDataError>
 where
     for<'de> T: Deserialize<'de> + Send,
 {
-    let ds = Dataset::files(glob)?.flat_map(make_try_flat_map_fn(deser_json_lines));
+    let ds = Dataset::files(glob)?.flat_map(make_begin_try_flat_map_fn(deser_json_lines));
     Ok(ds)
 }
 
-fn deser_json_lines<T>(path: PathBuf) -> Result<Vec<T>, KGDataError>
+pub fn deser_json_lines<T>(path: PathBuf) -> Result<Vec<T>, KGDataError>
 where
     for<'de> T: Deserialize<'de>,
 {
+    let ext = path.extension().and_then(OsStr::to_str).map(str::to_owned);
     let file = File::open(path)?;
-    let reader = BufReader::new(file);
+    let reader: Box<dyn BufRead> = if let Some(ext) = ext {
+        match ext.as_str() {
+            "gz" => Box::new(BufReader::new(GzDecoder::new(file))),
+            _ => unimplemented!(),
+        }
+    } else {
+        Box::new(BufReader::new(file))
+    };
 
     reader
         .lines()
         .map(|line| serde_json::from_str::<T>(&line?).map_err(|err| err.into()))
        .collect::<Result<Vec<T>, _>>()
 }
+
+pub fn merge_map_list<K, V>(
+    mut map: HashMap<K, Vec<V>>,
+    another: HashMap<K, Vec<V>>,
+) -> HashMap<K, Vec<V>>
+where
+    K: Hash + Eq,
+{
+    for (k, v) in another.into_iter() {
+        match map.get_mut(&k) {
+            Some(lst) => lst.extend(v),
+            None => {
+                map.insert(k, v);
+            }
+        }
+    }
+    map
+}
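The split into `make_begin_try_flat_map_fn` (plain input) and `make_try_flat_map_fn` (`Result` input) lets fallible stages chain: every stage after the first consumes and re-emits `Result<T, KGDataError>`, so a bad record or unreadable file surfaces at the final collect instead of panicking mid-pipeline. A sketch of the intended usage; the path and record type are hypothetical:

```rust
// Hypothetical pipeline over JSON-lines files (optionally gzipped), built on
// the functions added above plus make_try_filter_fn from the same module.
use kgdata::error::KGDataError;
use kgdata::mapreduce::{from_jl_files, make_try_filter_fn, ParallelDataset};
use serde::Deserialize;

#[derive(Deserialize)]
struct Record {
    id: String,
}

fn non_empty_ids() -> Result<Vec<Record>, KGDataError> {
    // Each item flowing through the pipeline is a Result<Record, KGDataError>.
    from_jl_files::<Record>("data/records/*.gz")?
        .filter(make_try_filter_fn(|r: &Record| !r.id.is_empty()))
        .collect::<Result<Vec<_>, _>>()
}
```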
diff --git a/src/mapreduce/miscop.rs b/src/mapreduce/miscop.rs
new file mode 100644
index 0000000..57bc118
--- /dev/null
+++ b/src/mapreduce/miscop.rs
@@ -0,0 +1,23 @@
+use rayon::prelude::*;
+
+use super::ParallelDataset;
+
+#[derive(Clone)]
+pub struct TakeAny<D> {
+    pub base: D,
+    pub n: usize,
+}
+
+impl<D> IntoParallelIterator for TakeAny<D>
+where
+    D: ParallelDataset,
+{
+    type Iter = rayon::iter::TakeAny<D::Iter>;
+    type Item = D::Item;
+
+    fn into_par_iter(self) -> Self::Iter {
+        self.base.into_par_iter().take_any(self.n)
+    }
+}
+
+impl<D> ParallelDataset for TakeAny<D> where D: ParallelDataset {}
diff --git a/src/mapreduce/mod.rs b/src/mapreduce/mod.rs
index 1f1688c..21c1f76 100644
--- a/src/mapreduce/mod.rs
+++ b/src/mapreduce/mod.rs
@@ -1,16 +1,21 @@
+use core::hash::Hash;
+use hashbrown::HashMap;
 use rayon::prelude::*;
 
 pub mod dataset;
 pub mod filterop;
+pub mod foldop;
 pub mod functions;
 pub mod mapop;
-// pub mod sortop;
+pub mod miscop;
+pub mod sortop;
 
+pub use self::dataset::*;
 pub use self::filterop::*;
 pub use self::functions::*;
 pub use self::mapop::*;
-// pub use self::sortop::*;
-pub use self::dataset::*;
+pub use self::miscop::*;
+pub use self::sortop::*;
 
 /// A note on the implementation: because most of the trait methods require Sized,
 /// using this as a trait object would make most of its methods unusable. To prevent early boxing errors, we require
@@ -35,38 +40,94 @@ pub trait ParallelDataset: Sized + Send + IntoParallelIterator {
     fn filter<F>(self, op: F) -> self::filterop::FilterOp<Self, F>
     where
         F: Fn(&Self::Item) -> bool + Sync,
-        Self: Sized,
     {
         self::filterop::FilterOp { base: self, op }
     }
 
-    // fn sort_by_key(self, op: F, ascending: bool) -> self::sortop::SortByKeyOp
-    // where
-    //     F: Fn(&Self::Item) -> K + Sync,
-    //     K: Ord + Send,
-    //     Self: Sized,
-    // {
-    //     self::sortop::SortByKeyOp {
-    //         base: self,
-    //         op,
-    //         ascending,
-    //     }
-    // }
+    fn fold<T, ID, F>(self, identity: ID, op: F) -> self::foldop::FoldOp<Self, ID, F>
+    where
+        F: (Fn(T, Self::Item) -> T) + Sync + Send,
+        ID: Fn() -> T + Sync + Send,
+        T: Send,
+    {
+        self::foldop::FoldOp {
+            base: self,
+            identity,
+            op,
+        }
+    }
 
-    fn collect<C>(self) -> C
+    fn reduce<ID, F>(self, identity: ID, op: F) -> Self::Item
     where
-        C: FromParallelDataset<Self::Item>,
+        F: (Fn(Self::Item, Self::Item) -> Self::Item) + Sync + Send,
+        ID: (Fn() -> Self::Item) + Sync + Send,
     {
-        C::from_par_dataset(self)
+        self.into_par_iter().reduce(identity, op)
+    }
+
+    fn sort_by_key<F, K>(self, op: F, ascending: bool) -> self::sortop::SortByKeyOp<Self, F>
+    where
+        F: Fn(&Self::Item) -> K + Sync,
+        K: Ord + Send,
+    {
+        self::sortop::SortByKeyOp {
+            base: self,
+            op,
+            ascending,
+        }
+    }
+
+    fn group_by<F, K>(self, key: F) -> MapDataset<K, Vec<Self::Item>>
+    where
+        F: Fn(&Self::Item) -> K + Sync,
+        K: Hash + Eq + Send,
+    {
+        let map = self
+            .fold(
+                HashMap::new,
+                |mut map: HashMap<K, Vec<Self::Item>>, item: Self::Item| {
+                    map.entry(key(&item)).or_default().push(item);
+                    map
+                },
+            )
+            .reduce(HashMap::new, merge_map_list);
+
+        MapDataset { map }
+    }
+
+    fn group_by_map<F1, F2, K, V>(self, key: F1, value: F2) -> MapDataset<K, Vec<V>>
+    where
+        F1: Fn(&Self::Item) -> K + Sync,
+        F2: Fn(&Self::Item) -> V + Sync,
+        K: Hash + Eq + Send,
+        V: Send,
+    {
+        let map = self
+            .fold(
+                HashMap::new,
+                |mut map: HashMap<K, Vec<V>>, item: Self::Item| {
+                    map.entry(key(&item)).or_default().push(value(&item));
+                    map
+                },
+            )
+            .reduce(HashMap::new, merge_map_list);
+
+        MapDataset { map }
+    }
+
+    fn count(self) -> usize {
+        self.into_par_iter().count()
     }
 
-    fn take(self, n: usize) -> Vec<Self::Item>
+    fn take_any(self, n: usize) -> TakeAny<Self> {
+        TakeAny { base: self, n }
+    }
+
+    fn collect<C>(self) -> C
     where
-        Self: Sized,
+        C: FromParallelDataset<Self::Item>,
     {
-        let mut res = self.collect::<Vec<_>>();
-        res.truncate(n);
-        res
+        C::from_par_dataset(self)
     }
 }
@@ -88,6 +149,19 @@ where
     }
 }
 
+impl<K, V> FromParallelDataset<(K, V)> for HashMap<K, V>
+where
+    K: Hash + Eq + Send,
+    V: Send,
+{
+    fn from_par_dataset<D>(dataset: D) -> Self
+    where
+        D: ParallelDataset<Item = (K, V)>,
+    {
+        dataset.into_par_iter().collect()
+    }
+}
+
 impl<I, E> FromParallelDataset<Result<I, E>> for Result<Vec<I>, E>
 where
     I: Send,
@@ -101,6 +175,20 @@ where
     }
 }
 
+impl<K, V, E> FromParallelDataset<Result<(K, V), E>> for Result<HashMap<K, V>, E>
+where
+    K: Hash + Eq + Send,
+    V: Send,
+    E: Send,
+{
+    fn from_par_dataset<D>(dataset: D) -> Self
+    where
+        D: ParallelDataset<Item = Result<(K, V), E>>,
+    {
+        dataset.into_par_iter().collect()
+    }
+}
+
 pub trait IntoParallelDataset {
     type Dataset: ParallelDataset<Item = Self::Item>;
     type Item: Send;
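The new `fold`/`reduce` pair is the classic two-phase parallel aggregation: `fold` builds one partial accumulator per rayon worker, and `reduce` merges the partials. `group_by` and `group_by_map` are built directly on it, with `merge_map_list` from functions.rs as the merge step. A sketch with made-up data:

```rust
// Sketch of the new combinators (not part of the patch): group_by folds items
// into per-thread HashMap<K, Vec<V>> buckets and merges them with
// merge_map_list; the resulting MapDataset is itself a ParallelDataset, so
// the new FromParallelDataset<(K, V)> impl can collect it into a HashMap.
use hashbrown::HashMap;
use kgdata::mapreduce::{Dataset, ParallelDataset};

fn demo() {
    let ds = Dataset { items: vec!["ant", "bee", "bat", "cow", "cat"] };
    let group_sizes: HashMap<u8, usize> = ds
        .group_by(|w| w.as_bytes()[0])
        .map(|(first_byte, words)| (first_byte, words.len()))
        .collect();
    assert_eq!(group_sizes[&b'b'], 2);
    assert_eq!(group_sizes[&b'c'], 2);
}
```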
diff --git a/src/mapreduce/sortop.rs b/src/mapreduce/sortop.rs
index 8240871..e30ee45 100644
--- a/src/mapreduce/sortop.rs
+++ b/src/mapreduce/sortop.rs
@@ -9,21 +9,19 @@ pub struct SortByKeyOp<D, F> {
     pub ascending: bool,
 }
 
-// impl<D, F> ParallelDataset for SortByKeyOp<D, F>
-// where
-//     D: ParallelDataset + IntoParallelIterator<Item = <D as ParallelDataset>::Item>,
-//     F: Fn(&<D as ParallelDataset>::Item) -> bool + Sync + Send,
-// {
-//     type Item = <D as ParallelDataset>::Item;
+impl<D, F, K> SortByKeyOp<D, F>
+where
+    D: ParallelDataset,
+    F: Fn(&D::Item) -> K + Sync + Send,
+    K: Ord + Send,
+{
+    fn collect(self) -> Vec<D::Item> {
+        let mut items: Vec<D::Item> = self.base.into_par_iter().collect();
+        if self.ascending {
+            items.sort_unstable_by_key(self.op);
+        } else {
+            items.sort_unstable_by_key(|item| std::cmp::Reverse((self.op)(item)));
+        }
 
-//     fn collect(self) -> Vec<Self::Item> {
-//         let mut items: Vec<Self::Item> = self.base.into_par_iter().collect();
-//         if self.ascending {
-//             items.sort_unstable_by_key(self.op);
-//         } else {
-//             items.sort_unstable_by_key(|item| std::cmp::Reverse((self.op)(item)));
-//         }
+        items
+    }
+}
 
-//         items
-//     }
-// }
diff --git a/src/models/property.rs b/src/models/property.rs
index e6bdef3..cc47989 100644
--- a/src/models/property.rs
+++ b/src/models/property.rs
@@ -1,7 +1,7 @@
 use crate::error::KGDataError;
 
 use super::{MultiLingualString, MultiLingualStringList};
-use hashbrown::{HashMap, HashSet};
+use hashbrown::HashMap;
 use serde::{Deserialize, Serialize};
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
diff --git a/src/pyo3helper/hashbrown.rs b/src/pyo3helper/hashbrown.rs
index e148a37..36e4df0 100644
--- a/src/pyo3helper/hashbrown.rs
+++ b/src/pyo3helper/hashbrown.rs
@@ -1,7 +1,7 @@
 use hashbrown::HashSet;
 use pyo3::{
     prelude::*,
-    types::{PyDict, PyList, PySet},
+    types::{PyList, PySet},
 };
 
 /// A zero-cost abstraction for automatically receiving HashSet from Python.
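The scripts.rs rewrite below replaces the precomputed-map design with on-demand scans, and its `get_score` blends degree centrality with the distance from the entity to the queried class. A worked example of that formula, rewritten as a free function for illustration only (same constants as the patch):

```rust
// Standalone mirror of get_score below: out-degrees dominate (0.95 vs 0.05),
// degrees are damped by 0.1, and type distance adds a steep penalty.
fn score(wp_in: f32, wp_out: f32, kg_in: f32, kg_out: f32, dist_penalty: f32) -> f32 {
    let outscale = 0.95;
    let inscale = 1.0 - outscale;
    let degree_scale = 0.1;
    let wp_score = wp_in * inscale + wp_out * outscale;
    let kg_score = kg_in * inscale + kg_out * outscale;
    (wp_score + kg_score) * degree_scale + dist_penalty
}

fn main() {
    // Direct instance of the class (distance 0, no penalty):
    // ((10*0.05 + 100*0.95) + (20*0.05 + 200*0.95)) * 0.1 = 28.65
    assert!((score(10.0, 100.0, 20.0, 200.0, 0.0) - 28.65).abs() < 1e-3);
    // One hop up the class hierarchy costs -500, dwarfing the degree signal.
    assert!(score(10.0, 100.0, 20.0, 200.0, -500.0) < 0.0);
}
```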
diff --git a/src/python/scripts.rs b/src/python/scripts.rs
index 6e34d88..aae9435 100644
--- a/src/python/scripts.rs
+++ b/src/python/scripts.rs
@@ -1,51 +1,187 @@
-use std::{fs::File, io::BufReader, path::PathBuf};
-
+use crate::error::KGResult;
 use crate::models::MultiLingualString;
 use crate::{error::into_pyerr, mapreduce::*};
 use hashbrown::{HashMap, HashSet};
 use pyo3::prelude::*;
+use pyo3::types::{IntoPyDict, PyDict};
 use serde::{Deserialize, Serialize};
-use std::io::BufRead;
 
 #[pyclass]
 pub struct GetRepresentativeValue {
-    types_and_degrees: HashMap<String, EntityTypesAndDegrees>,
-    id2labels: HashMap<String, EntityLabel>,
+    pub data_dir: String,
+    pub class_ids: HashSet<String>,
+    pub kgname: String,
+    pub topk: usize,
+}
+
+impl GetRepresentativeValue {
+    pub fn get_score(&self, ent: &EntityTypesAndDegrees, class_id: &str) -> f32 {
+        let outscale = 0.95;
+        let inscale = 1.0 - outscale;
+        let degree_scale = 0.1;
+
+        let wp_score = ent.wikipedia_indegree.unwrap_or(0) as f32 * inscale
+            + ent.wikipedia_outdegree.unwrap_or(0) as f32 * outscale;
+
+        let db_score = ent.indegree as f32 * inscale + ent.outdegree as f32 * outscale;
+
+        let dist_score = match ent.types[class_id] {
+            0 => 0.0,
+            1 => -500.0,
+            2 => -10000.0,
+            3 => -30000.0,
+            dist => -(dist as f32) * 20000.0,
+        };
+
+        (wp_score + db_score) * degree_scale + dist_score
+    }
+
+    pub fn get_ent_types_and_degrees_files(&self) -> String {
+        format!("{}/entity_types_and_degrees/*.gz", self.data_dir)
+    }
+
+    pub fn get_ent_label_files(&self) -> String {
+        format!("{}/entity_labels/*.gz", self.data_dir)
+    }
 }
 
 #[pymethods]
 impl GetRepresentativeValue {
     #[new]
-    pub fn new(data_dir: &str, class_ids: Vec<String>, kgname: &str) -> PyResult<Self> {
-        let filtered_ids: HashSet<String> = HashSet::from_iter(class_ids);
-        let types_and_degrees = from_jl_files::<EntityTypesAndDegrees>(&format!(
-            "{}/entity_types_and_degrees/*.gz",
-            data_dir
-        ))
-        .map_err(into_pyerr)?
-        .filter(make_try_filter_fn(|x: &EntityTypesAndDegrees| {
-            !filtered_ids.is_disjoint(&x.types.keys().cloned().collect::<HashSet<_>>())
-        }))
-        .map(make_try_fn(|x: EntityTypesAndDegrees| {
-            Ok((x.id.clone(), x))
-        }))
-        .collect::<Result<Vec<_>, _>>()
-        .map_err(into_pyerr)?;
-
-        let types_and_degrees =
-            HashMap::<String, EntityTypesAndDegrees>::from_iter(types_and_degrees);
-
-        unimplemented!()
+    pub fn new(data_dir: String, class_ids: HashSet<String>, kgname: String, topk: usize) -> Self {
+        Self {
+            data_dir,
+            class_ids,
+            kgname,
+            topk,
+        }
     }
 
-    // pub fn __call__(&self, class_ids: Vec<String>, k: usize) -> Result<Vec<String>, KGDataError> {
+    pub fn get_examples<'t>(&self, py: Python<'t>) -> PyResult<&'t PyDict> {
+        let matched_ents =
+            from_jl_files::<EntityTypesAndDegrees>(&self.get_ent_types_and_degrees_files())
+                .map_err(into_pyerr)?
+                .filter(make_try_filter_fn(|x: &EntityTypesAndDegrees| {
+                    x.types.iter().any(|cid| self.class_ids.contains(cid.0))
+                }))
+                .collect::<KGResult<Vec<_>>>()
+                .map_err(into_pyerr)?;
+
+        let type2examples = RefDataset::new(&matched_ents)
+            .flat_map(|ent| {
+                ent.types
+                    .keys()
+                    .filter_map(|cid| {
+                        if self.class_ids.contains(cid) {
+                            Some((cid, ent))
+                        } else {
+                            None
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            })
+            .group_by_map(|item| item.0, |item| item.1)
+            .map(|item| {
+                let mut newents = item
+                    .1
+                    .into_iter()
+                    .map(|ent| (ent, self.get_score(ent, item.0)))
+                    .collect::<Vec<_>>();
+                newents.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
+                newents.truncate(self.topk);
+                (item.0, newents.into_iter().map(|x| x.0).collect::<Vec<_>>())
+            })
+            .collect::<HashMap<_, _>>();
+
+        let matched_ent_ids = matched_ents
+            .iter()
+            .map(|ent| &ent.id)
+            .collect::<HashSet<_>>();
+
+        let matched_ent_labels = from_jl_files::<EntityLabel>(&self.get_ent_label_files())
+            .map_err(into_pyerr)?
+            .filter(make_try_filter_fn(|x: &EntityLabel| {
+                matched_ent_ids.contains(&x.id)
+            }))
+            .map(make_try_fn(|x: EntityLabel| Ok((x.id.clone(), x))))
+            .collect::<KGResult<HashMap<_, _>>>()
+            .map_err(into_pyerr)?;
+
+        let output = PyDict::new(py);
 
-    //     Ok(())
-    // }
+        for cid in &self.class_ids {
+            output.set_item(
+                cid,
+                type2examples[cid]
+                    .iter()
+                    .map(|ent| {
+                        let dict = PyDict::new(py);
+                        dict.set_item("id", &ent.id)?;
+                        dict.set_item(
+                            "label",
+                            multi_lingual_string_to_dict(py, &matched_ent_labels[&ent.id].label)?,
+                        )?;
+                        dict.set_item("score", self.get_score(ent, cid))?;
+                        Ok(dict)
+                    })
+                    .collect::<PyResult<Vec<_>>>()?,
+            )?;
+        }
+
+        Ok(output)
+    }
+
+    /// Calculate the number of entities for each type. This is useful when we want to determine
+    /// which types we should group together in one pass. Some big types that require more memory
+    /// should run alone.
+    pub fn calculate_stats<'t>(&self, py: Python<'t>) -> PyResult<&'t PyDict> {
+        let type_counts =
+            from_jl_files::<EntityTypesAndDegrees>(&self.get_ent_types_and_degrees_files())
+                .map_err(into_pyerr)?
+                .flat_map(make_try_flat_map_fn(|x: EntityTypesAndDegrees| {
+                    Ok(self.class_ids.iter().filter_map(move |cid| {
+                        if x.types.contains_key(cid) {
+                            Some(cid.as_str())
+                        } else {
+                            None
+                        }
+                    }))
+                }))
+                .fold(
+                    || Ok(HashMap::new()),
+                    |map: KGResult<HashMap<&str, usize>>, item: KGResult<&str>| {
+                        let mut map = map?;
+                        let item = item?;
+                        if map.contains_key(item) {
+                            *map.get_mut(item).unwrap() += 1;
+                        } else {
+                            map.insert(item, 1);
+                        }
+                        Ok(map)
+                    },
+                )
+                .reduce(
+                    || Ok(HashMap::new()),
+                    |map: KGResult<HashMap<&str, usize>>, map2: KGResult<HashMap<&str, usize>>| {
+                        let mut map = map?;
+                        for (k, v) in map2?.into_iter() {
+                            if map.contains_key(&k) {
+                                *map.get_mut(&k).unwrap() += v;
+                            } else {
+                                map.insert(k, v);
+                            }
+                        }
+                        Ok(map)
+                    },
+                )
+                .map_err(into_pyerr)?;
+
+        Ok(type_counts.into_py_dict(py))
+    }
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
-struct EntityTypesAndDegrees {
+pub struct EntityTypesAndDegrees {
     id: String,
 
     types: HashMap<String, usize>,
@@ -63,14 +199,12 @@ struct EntityLabel {
     label: MultiLingualString,
 }
 
-fn deser_ent_types(path: PathBuf) -> PyResult<Vec<EntityTypesAndDegrees>> {
-    let file = File::open(path).map_err(into_pyerr)?;
-    let reader = BufReader::new(file);
-    reader
-        .lines()
-        .map(|line| {
-            serde_json::from_str::<EntityTypesAndDegrees>(&line.map_err(into_pyerr)?)
-                .map_err(into_pyerr)
-        })
-        .collect::<PyResult<Vec<EntityTypesAndDegrees>>>()
+fn multi_lingual_string_to_dict<'t>(
+    py: Python<'t>,
+    label: &MultiLingualString,
+) -> PyResult<&'t PyDict> {
+    let dict = PyDict::new(py);
+    dict.set_item("lang2value", (&label.lang2value).into_py_dict(py))?;
+    dict.set_item("lang", &label.lang)?;
+    Ok(dict)
 }
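With `GetRepresentativeValue` registered in src/lib.rs, the same pyclass can also be driven from Rust, mirroring the commented-out code in src/main.rs above. A sketch under stated assumptions: the data path is hypothetical, and running this binary needs `pyo3/auto-initialize`, as wired up in .vscode/tasks.json:

```rust
// Hypothetical driver (not part of the patch); the constructor and
// calculate_stats signatures come from src/python/scripts.rs above.
use hashbrown::HashSet;
use kgdata::python::scripts::GetRepresentativeValue;
use pyo3::prelude::*;

fn main() -> PyResult<()> {
    let args = GetRepresentativeValue::new(
        "data/dbpedia/20221201".to_string(), // hypothetical data_dir
        HashSet::from_iter(["http://dbpedia.org/ontology/Person".to_string()]),
        "dbpedia".to_string(),
        1000,
    );
    Python::with_gil(|py| {
        // Returns a &PyDict mapping each class id to its entity count,
        // useful for deciding which classes can share one pass.
        let stats = args.calculate_stats(py)?;
        println!("{:?}", stats);
        Ok(())
    })
}
```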