-
Notifications
You must be signed in to change notification settings - Fork 0
/
join_categories.py
65 lines (54 loc) · 2 KB
/
join_categories.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from src.config import Config
def misinterpret(text: str, false_encoding: str = "latin1") -> str:
    """Reproduce the mojibake you get by decoding UTF-8 text with the wrong codec.

    The UTF-8 byte encoding of *text* is decoded as *false_encoding*.
    For example, ``"æ"`` becomes ``"Ã¦"`` under the default ``"latin1"``.

    Args:
        text: Correctly decoded text.
        false_encoding: Codec used for the (wrong) decode step.

    Returns:
        The mis-decoded string.

    Raises:
        UnicodeDecodeError: If the UTF-8 bytes are not valid in
            *false_encoding*. This cannot happen for ``"latin1"``, which maps
            every byte value.
    """
    utf8_bytes = text.encode("utf-8")
    return utf8_bytes.decode(false_encoding)
def get_all_parents(
    joined: pd.DataFrame,
    top_level: pd.Series,
    num_levels: int = 20,
) -> pd.DataFrame:
    """Map categories to the top-level categories they descend from.

    Walks the ``child``/``parent`` edge list in *joined* breadth-first,
    starting from the edges whose ``parent`` is in *top_level*. At each depth,
    an unresolved edge whose ``parent`` was resolved at the previous depth is
    rewritten so that its ``parent`` column holds that top-level category.
    The edges found at every depth are then concatenated.

    Args:
        joined: Edge list with ``child`` and ``parent`` columns, one row per
            child -> parent link.
        top_level: Categories that act as roots of the walk. Any list-like
            accepted by ``Series.isin`` works (``main`` passes a NumPy array).
        num_levels: Maximum number of depths to collect. Depth 0 is the
            edges pointing directly at *top_level*. Descendants deeper than
            this are silently left out.

    Returns:
        A DataFrame with ``child`` and ``parent`` columns, where every
        ``parent`` is a member of *top_level*. Duplicate rows are removed and
        the index is reset. The frame is empty (with those columns) when
        ``num_levels <= 0``.
    """
    parent_list = []
    # Depth 0: edges that already point at a top-level category.
    new_parents = joined[joined["parent"].isin(top_level)]
    # Edges still to be resolved: neither end is a top-level category.
    # NOTE(review): later depths drop *all* remaining edges of a resolved child
    # (see the orphans filter in the loop), but here the other edges of
    # depth-0 children are kept. Those children can therefore be resolved
    # again through deeper paths. Confirm that this asymmetry is intended.
    orphans = joined[
        ~joined["parent"].isin(top_level) & ~joined["child"].isin(top_level)
    ]
    for level in tqdm(range(num_levels)):
        parent_list.append(new_parents)
        # Stop once the frontier is empty, since every further merge would be
        # empty too. Also stop on the last requested level, whose successor
        # would be computed and then discarded.
        if new_parents.empty or level == num_levels - 1:
            break
        # Link each orphan edge to the frontier via orphan.parent ==
        # frontier.child. Keep the orphan's child and the frontier's
        # (top-level) parent. The left merge + dropna keeps only matched rows,
        # in the orphans' order.
        new_parents = (
            orphans.merge(new_parents, how="left", left_on="parent", right_on="child")
            .rename(columns={"child_x": "child", "parent_y": "parent"})[
                ["child", "parent"]
            ]
            .dropna()
        )
        # A child resolved at this depth is not revisited at deeper levels.
        orphans = orphans[~orphans["child"].isin(new_parents["child"])]
    if not parent_list:
        # num_levels <= 0: pd.concat([]) would raise, so return an empty table.
        return pd.DataFrame(columns=["child", "parent"])
    return pd.concat(parent_list).drop_duplicates().reset_index(drop=True)
def main(args: argparse.Namespace):
    """Build and save the category -> top-level-ancestor table for one wiki.

    Loads the JSON config at ``args.config_path`` and uses its ``prefix`` to
    read two files:
    ``local_data/{prefix}wiki-category-ids.csv`` and
    ``local_data/{prefix}wiki-latest-categorylinks.csv``.
    It joins them into a ``child``/``parent`` edge list and takes the children
    of ``config.top_level`` as the roots. The result of ``get_all_parents`` is
    written to ``local_data/{prefix}wiki-all-parents.csv``.
    """
    config: Config = Config.from_json(args.config_path)
    prefix = config.prefix

    category_ids = pd.read_csv(f"local_data/{prefix}wiki-category-ids.csv")
    category_links = pd.read_csv(f"local_data/{prefix}wiki-latest-categorylinks.csv")

    # Attach a title to each link's cl_from id, then keep just the edge.
    linked = category_links.merge(category_ids, left_on="cl_from", right_on="cat_id")
    edges = linked.rename(columns={"cat_title": "child", "cl_to": "parent"})
    edges = edges[["child", "parent"]]

    # NOTE(review): the parent column (cl_to, from the categorylinks CSV) is
    # compared against a mis-decoded top-level title. This suggests cl_to
    # values are stored as mojibake. The child column comes from cat_title in
    # a different CSV. If that file is decoded correctly, non-ASCII titles
    # will not match across the two columns inside get_all_parents.
    # TODO confirm both CSVs share the same encoding.
    root_title = misinterpret(config.top_level)
    points_at_root = edges["parent"] == root_title
    top_level = edges.loc[points_at_root, "child"].values

    ancestors = get_all_parents(edges, top_level)
    ancestors.to_csv(f"local_data/{prefix}wiki-all-parents.csv", index=False)
if __name__ == "__main__":
    # Command-line entry point: the only option is the path to the JSON config.
    cli = argparse.ArgumentParser(description="Join categories with their parents.")
    cli.add_argument("--config-path", type=Path, default=Path("da-config.json"))
    main(args=cli.parse_args())