utils.py
import sys
import pickle
import random


def file_to_wordset(filename):
    '''Converts a file with a word per line to a Python set'''
    words = []
    with open(filename, 'r') as f:
        for line in f:
            words.append(line.strip())
    return set(words)
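
# Usage sketch (the file name 'stopwords.txt' is illustrative, not part of
# this repo's guaranteed layout):
#   stopwords = file_to_wordset('stopwords.txt')
#   is_stopword = 'the' in stopwords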


def write_status(i, total):
    '''Writes status of a process to console'''
    sys.stdout.write('\r')
    sys.stdout.write('Processing %d/%d' % (i, total))
    sys.stdout.flush()
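
# Typical progress loop (sketch; `items` and `process` are stand-ins):
#   for i, item in enumerate(items):
#       process(item)
#       write_status(i + 1, len(items))
#   print()  # move past the carriage-returned status line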


def save_results_to_csv(results, csv_file):
    '''Saves a list of (tweet_id, prediction) tuples to a CSV file in the
    Kaggle submission format'''
    # 'f' used instead of 'csv' so the stdlib csv module is not shadowed
    with open('results/' + csv_file, 'w') as f:
        f.write('id,prediction\n')
        for tweet_id, pred in results:
            f.write('%s,%s\n' % (tweet_id, pred))
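
# Example call (note: open() will not create the 'results/' directory, so it
# must already exist):
#   save_results_to_csv([('1', 1), ('2', 0)], 'submission.csv')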


def top_n_words(pkl_file_name, N, shift=0):
    """Returns a dictionary of the form {word: rank} of the top N words from
    a pickle file holding an nltk FreqDist object generated by stats.py.

    Args:
        pkl_file_name (str): Name of the pickle file
        N (int): Number of words to get
        shift (int, optional): Amount by which to shift the rank from 0

    Returns:
        dict: Of the form {word: rank}
    """
    with open(pkl_file_name, 'rb') as pkl_file:
        freq_dist = pickle.load(pkl_file)
    most_common = freq_dist.most_common(N)
    words = {word: i + shift for i, (word, _) in enumerate(most_common)}
    return words
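
# Illustrative call ('freqdist.pkl' is a placeholder file name):
#   unigram_ranks = top_n_words('freqdist.pkl', N=90000, shift=1)
#   # e.g. {'the': 1, 'to': 2, ...}, with ranks starting at `shift`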


def top_n_bigrams(pkl_file_name, N, shift=0):
    """Returns a dictionary of the form {bigram: rank} of the top N bigrams
    from a pickle file holding a Counter object generated by stats.py.

    Args:
        pkl_file_name (str): Name of the pickle file
        N (int): Number of bigrams to get
        shift (int, optional): Amount by which to shift the rank from 0

    Returns:
        dict: Of the form {bigram: rank}
    """
    with open(pkl_file_name, 'rb') as pkl_file:
        freq_dist = pickle.load(pkl_file)
    most_common = freq_dist.most_common(N)
    bigrams = {bigram: i + shift for i, (bigram, _) in enumerate(most_common)}
    return bigrams
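
# Analogous call for bigrams; keys are the bigram tuples stored in the
# Counter ('freqdist-bi.pkl' is a placeholder file name):
#   bigram_ranks = top_n_bigrams('freqdist-bi.pkl', N=10000)
#   # e.g. {('i', 'love'): 0, ('love', 'you'): 1, ...}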


def split_data(tweets, validation_split=0.1):
    """Splits the data into training and validation sets.

    Args:
        tweets (list): List of tuples
        validation_split (float, optional): Fraction of the data to reserve
            for validation (e.g. 0.1 reserves 10%)

    Returns:
        (list, list): (training set, validation set)
    """
    index = int((1 - validation_split) * len(tweets))
    random.shuffle(tweets)  # shuffles the caller's list in place
    return tweets[:index], tweets[index:]
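

if __name__ == '__main__':
    # Minimal self-check with synthetic data (illustrative only, not part of
    # the original pipeline): ten fake (tweet_id, label) tuples, 20% split.
    dummy_tweets = [(str(i), i % 2) for i in range(10)]
    train, val = split_data(dummy_tweets, validation_split=0.2)
    print('train: %d examples, validation: %d examples' %
          (len(train), len(val)))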