# generate_synthpai.py
#
# Attaches thread structure (thread_id, parent_id, children) to the cleaned
# synthetic comments and writes the result to data/synthpai.jsonl.
import json
import os
import hashlib
def comment_to_hex(comment, size=50):
    """Return a SHA-256 hex digest that identifies ``comment``.

    The digest is taken over the first ``size`` characters of the comment's
    whitespace-stripped ``"text"`` followed by its ``"username"``.
    """
    stripped = comment["text"].strip()
    cutoff = min(size, len(stripped))
    key = stripped[:cutoff] + comment["username"]
    return hashlib.sha256(key.encode()).hexdigest()
def generate_thread_comments(
    comments, thread_folder="data/thread/generated_threads/json_threads"
):
    """Attach thread structure to ``comments`` in place.

    Every ``.json`` file found under ``thread_folder`` (searched recursively)
    is expected to hold one thread as a JSON object on its first line, with
    nested replies under the ``"children"`` key.  Each node of the thread is
    matched to an entry of ``comments`` through ``comment_to_hex``; matched
    entries receive ``"thread_id"``, ``"parent_id"`` and ``"children"`` keys.

    Args:
        comments: Mapping of comment id to comment dict.  Each comment dict
            must provide ``"id"``, ``"text"`` and ``"username"``.  The dicts
            are mutated in place.
        thread_folder: Directory that is walked for thread files.  Defaults
            to the location used by the original script.

    Returns:
        None.  Progress and unmatched comments are reported with ``print``.
    """
    # Index comments by their full hash and by a shorter-prefix hash that is
    # used as a fallback when the text of a thread node was modified.
    full_hash_to_id = {}
    short_hash_to_id = {}
    for comment in comments.values():
        full_hash_to_id[comment_to_hex(comment)] = comment["id"]
        short_hash_to_id[comment_to_hex(comment, size=20)] = comment["id"]

    def walk_thread(comment, thread_id, parent_id=None):
        """Resolve one thread node, link it to its parent, then recurse."""
        if parent_id is None:
            # Thread root.  "root" is a sentinel that is handed to the
            # children when the root itself cannot be matched.
            current_id = "root"
            try:
                root_id = full_hash_to_id[comment_to_hex(comment)]
                root_obj = comments[root_id]
            except KeyError:
                print(f"Root Comment not found: {comment['text'][:100]}")
            else:
                root_obj["thread_id"] = thread_id
                root_obj["parent_id"] = None
                root_obj["children"] = []
                current_id = root_id
        else:
            resolved = True
            full_hash = comment_to_hex(comment)
            if full_hash in full_hash_to_id:
                current_id = full_hash_to_id[full_hash]
            else:
                print(
                    f"Comment not found: {comment['text'][:100]}"
                )  # Intermediary comment not found
                short_hash = comment_to_hex(comment, size=20)
                if short_hash in short_hash_to_id:
                    # Fallback for short modified comments
                    current_id = short_hash_to_id[short_hash]
                else:
                    # We skip this comment: its replies are attached to the
                    # nearest resolved ancestor instead.
                    resolved = False
                    current_id = parent_id

            if resolved:
                comment_obj = comments[current_id]
                comment_obj["parent_id"] = parent_id
                comment_obj["thread_id"] = thread_id
                comment_obj.setdefault("children", [])
                # Only a resolved comment is recorded as a child; a skipped
                # one must not add the parent's own id to the parent.
                if parent_id != "root":
                    comments[parent_id]["children"].append(current_id)

        # Set children
        if "children" in comment:
            for child in comment["children"]:
                walk_thread(child, thread_id, current_id)

    # Iterate through all files in all subfolders
    for root, _dirs, files in os.walk(thread_folder):
        print(f"Processing file: {files}")
        for filename in files:
            if not filename.endswith(".json"):
                continue
            thread_path = os.path.join(root, filename)
            # Only the first line is used, so read just that line.
            with open(thread_path, "r", encoding="utf-8") as handle:
                first_line = handle.readline()
            if not first_line.strip():
                print(f"Skipping empty thread file: {thread_path}")
                continue
            # The thread id is the file name up to its first dot.
            thread_id = filename.split(".")[0]
            walk_thread(json.loads(first_line), thread_id)
def _stringify_age_guesses(comment):
    """Convert the guesses of every ``age`` prediction of ``comment`` to str, in place."""
    predictions = comment.get("guesses")
    if predictions is None:
        return
    for pred in predictions:
        if pred.get("feature") == "age":
            pred["guesses"] = list(map(str, pred["guesses"]))


def main():
    """Build ``data/synthpai.jsonl`` from the cleaned comments and thread files.

    Reads one JSON comment per line from ``data/thread/synth_clean.jsonl``,
    attaches thread information with ``generate_thread_comments``, reports
    comments that ended up without a ``"thread_id"``, converts age guesses to
    strings and writes one JSON comment per line to ``data/synthpai.jsonl``.
    """
    comments = {}
    with open("data/thread/synth_clean.jsonl", "r", encoding="utf-8") as source:
        for line in source:
            if not line.strip():
                # Tolerate blank lines in the JSONL input.
                continue
            comment = json.loads(line)
            comments[comment["id"]] = comment

    generate_thread_comments(comments)

    missing = 0
    for comment in comments.values():
        if "thread_id" not in comment:
            missing += 1
            print(
                f"Thread not found for comment: {comment['text'][:100].strip()} - {comment['id']}"
            )
    print(f"Number of comments without thread: {missing}")

    # convert age guesses to str format; done before the output file is
    # opened so a malformed record cannot leave a half-written file behind
    for comment in comments.values():
        _stringify_age_guesses(comment)

    # Store updated comments
    with open("data/synthpai.jsonl", "w", encoding="utf-8") as sink:
        for comment in comments.values():
            sink.write(json.dumps(comment) + "\n")


if __name__ == "__main__":
    main()