forked from chuchro3/Warfarin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmake_final_plot.py
138 lines (104 loc) · 6.74 KB
/
make_final_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import matplotlib.pyplot as plt
import pickle
import numpy as np
# #def lineplotCI(x_data, y_data, sorted_x, low_CI, upper_CI, x_label, y_label, title, show=False):
def lineplotCI(cutoff, data, x_label, y_label, title, show=False):
    """Draw one line with a shaded confidence band for every series.

    Args:
        cutoff: Number of leading points dropped from every series before
            plotting (applied as a ``[cutoff:]`` slice).
        data: Iterable of ``(name, color, y_data, low_CI, upper_CI)`` tuples.
            ``name`` becomes the legend entry, ``color`` is used for both the
            line and the band, and the three sequences are indexed by
            position (x is simply ``0 .. len(y_data) - 1``).
        x_label: Text for the x axis.
        y_label: Text for the y axis.
        title: Figure title.
        show: When true, call ``plt.show()`` once everything is drawn.
    """
    for label, shade, centre, band_lo, band_hi in data:
        # x positions are the sample indices that survive the cutoff.
        steps = range(len(centre))[cutoff:]

        # Solid, fully opaque line for the central curve.
        plt.plot(steps, centre[cutoff:],
                 lw=1, color=shade, alpha=1, label=label)

        # Translucent band between the lower and upper bounds.
        plt.fill_between(steps, band_lo[cutoff:], band_hi[cutoff:],
                         color=shade, alpha=0.4)

    # Figure-level decoration.
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend(loc='best')

    if show:
        plt.show()
def get_CI_data(file_names, flip_neg=False, *, data_dir='batch'):
    """Load one pickled curve per run and return mean plus a 95% band.

    Every file is unpickled and iterated to obtain one run's sequence of
    values; position ``i`` of each run is treated as the same timestep.

    Args:
        file_names: Iterable of file names inside ``data_dir``. A one-shot
            iterator (e.g. a ``map`` object) is accepted.
        flip_neg: When true, negate every loaded value before aggregating.
        data_dir: Directory that holds the pickles. Defaults to ``'batch'``,
            the location that used to be hard-coded.

    Returns:
        Tuple ``(mean, lower, upper)`` of numpy arrays indexed by timestep,
        where ``lower``/``upper`` are ``mean -/+ 1.960 * std / sqrt(n_runs)``
        and ``std`` is numpy's default (population, ``ddof=0``) standard
        deviation across runs.

    Raises:
        ValueError: If ``file_names`` is empty or the runs differ in length.
    """
    import os

    paths = [os.path.join(data_dir, name) for name in file_names]
    if not paths:
        raise ValueError('get_CI_data needs at least one file name')

    runs = []
    for path in paths:
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at result files produced by trusted runs.
        # The context manager guarantees the handle is closed after loading.
        with open(path, 'rb') as handle:
            runs.append(list(pickle.load(handle)))

    lengths = {len(run) for run in runs}
    if len(lengths) > 1:
        raise ValueError(
            'runs have different lengths: {}'.format(sorted(lengths)))

    # Shape (n_runs, n_timesteps, ...): aggregate across runs on axis 0.
    samples = np.array(runs)
    if flip_neg:
        samples = -samples

    mean = np.mean(samples, axis=0)
    std = np.std(samples, axis=0)

    Z = 1.960  # two-sided 95% normal quantile
    margin = Z * std / np.sqrt(len(paths))
    return mean, mean - margin, mean + margin
# ----------- ERROR -----------
# One result file per run for each policy; get_CI_data turns every list into
# a (mean, lower, upper) triple.

# These names contain a space, so they are spelled out as a list rather than
# being split from a single string.
lin_ucb_error_data_files = [
    'errorLinear UCB440799', 'errorLinear UCB624516',
    'errorLinear UCB708433', 'errorLinear UCB777131',
    'errorLinear UCB972487', 'errorLinear UCB572277',
    'errorLinear UCB674927', 'errorLinear UCB734551',
    'errorLinear UCB783172', 'errorLinear UCB973196',
]
# Lin UCB is the only series loaded with negation enabled.
# NOTE(review): presumably those runs were stored with the opposite sign - confirm.
lin_ucb_error_data = get_CI_data(lin_ucb_error_data_files, flip_neg=True)

fixed_error_data_files = """
    errorFixed232923 errorFixed280014 errorFixed294651 errorFixed386298
    errorFixed485167 errorFixed596273 errorFixed694588 errorFixed720326
    errorFixed942363 errorFixed997216
""".split()
fixed_error_data = get_CI_data(fixed_error_data_files)

clinical_error_data_files = """
    errorWarfarinClinicalDose109457 errorWarfarinClinicalDose473473
    errorWarfarinClinicalDose578688 errorWarfarinClinicalDose616370
    errorWarfarinClinicalDose723731 errorWarfarinClinicalDose334008
    errorWarfarinClinicalDose487420 errorWarfarinClinicalDose584850
    errorWarfarinClinicalDose644573 errorWarfarinClinicalDose790409
""".split()
clinical_error_data = get_CI_data(clinical_error_data_files)

pharm_error_data_files = """
    errorWarfarinPharmacogeneticDose286728 errorWarfarinPharmacogeneticDose409989
    errorWarfarinPharmacogeneticDose507431 errorWarfarinPharmacogeneticDose668561
    errorWarfarinPharmacogeneticDose802457 errorWarfarinPharmacogeneticDose352938
    errorWarfarinPharmacogeneticDose458765 errorWarfarinPharmacogeneticDose661494
    errorWarfarinPharmacogeneticDose708887 errorWarfarinPharmacogeneticDose896116
""".split()
pharm_error_data = get_CI_data(pharm_error_data_files)

lasso_error_data_files = """
    errorLASSO_Bandit_nodis128985 errorLASSO_Bandit_nodis245874
    errorLASSO_Bandit_nodis419803 errorLASSO_Bandit_nodis466501
    errorLASSO_Bandit_nodis644341 errorLASSO_Bandit_nodis187417
    errorLASSO_Bandit_nodis382761 errorLASSO_Bandit_nodis438574
    errorLASSO_Bandit_nodis582776 errorLASSO_Bandit_nodis970244
""".split()
lasso_error_data = get_CI_data(lasso_error_data_files)

# Loaded but currently not drawn in the figure below.
lin_ucb_disease_error_data_files = """
    errorLinUCB_dis176459 errorLinUCB_dis448861 errorLinUCB_dis506081
    errorLinUCB_dis668883 errorLinUCB_dis796048 errorLinUCB_dis229609
    errorLinUCB_dis487298 errorLinUCB_dis593515 errorLinUCB_dis787344
    errorLinUCB_dis942446
""".split()
lin_ucb_disease_error_data = get_CI_data(lin_ucb_disease_error_data_files)

lasso_dis_error_data_files = """
    errorLASSO_Bandit_dis122109 errorLASSO_Bandit_dis421003
    errorLASSO_Bandit_dis492863 errorLASSO_Bandit_dis602104
    errorLASSO_Bandit_dis773923 errorLASSO_Bandit_dis208179
    errorLASSO_Bandit_dis427547 errorLASSO_Bandit_dis558990
    errorLASSO_Bandit_dis743081 errorLASSO_Bandit_dis897925
""".split()
lasso_dis_error_data = get_CI_data(lasso_dis_error_data_files)

# (legend label, colour, mean, lower, upper) for every curve in the figure.
error_series = [
    ('Lin UCB', '#539caf', *lin_ucb_error_data),
    ('Fixed', '#FFA07A', *fixed_error_data),
    ('Clinical Oracle', '#228B22', *clinical_error_data),
    ('Pharmacogenetic Oracle', '#EE82EE', *pharm_error_data),
    ('LASSO Bandit', '#FFD700', *lasso_error_data),
    ('LASSO w/ D. features', '#18CAF5', *lasso_dis_error_data),
    # ('Lin UCB w/ D. features', '#18CAF5', *lin_ucb_disease_error_data),
]
# The first 100 timesteps are left out of the figure.
lineplotCI(100, error_series,
           'timestep t', 'cumulative error', 'Error Rate', show=True)
# ----------- END ERROR -----------
# ----------- REGRET -----------
# Each regret file shares its name with the matching error file, with
# 'error' swapped for 'regret'. The names are materialised as lists: a bare
# map() object is a one-shot iterator and would be empty on any later use.
lin_ucb_regret_files = [
    name.replace('error', 'regret') for name in lin_ucb_error_data_files]
lin_ucb_regret_data = get_CI_data(lin_ucb_regret_files)

fixed_regret_files = [
    name.replace('error', 'regret') for name in fixed_error_data_files]
fixed_regret_data = get_CI_data(fixed_regret_files)

clinical_regret_files = [
    name.replace('error', 'regret') for name in clinical_error_data_files]
clinical_regret_data = get_CI_data(clinical_regret_files)

pharm_regret_files = [
    name.replace('error', 'regret') for name in pharm_error_data_files]
pharm_regret_data = get_CI_data(pharm_regret_files)

lasso_regret_files = [
    name.replace('error', 'regret') for name in lasso_error_data_files]
lasso_regret_data = get_CI_data(lasso_regret_files)

# Loaded but currently not drawn in the figure below.
lin_ucb_disease_regret_files = [
    name.replace('error', 'regret')
    for name in lin_ucb_disease_error_data_files]
lin_ucb_disease_regret_data = get_CI_data(lin_ucb_disease_regret_files)

lasso_dis_regret_files = [
    name.replace('error', 'regret') for name in lasso_dis_error_data_files]
lasso_dis_regret_data = get_CI_data(lasso_dis_regret_files)

# Straight reference line y = 0.39x, built once and reused as its own lower
# and upper bound so the shaded band has zero width.
# NOTE(review): 5528 is presumably the number of timesteps per run and 0.39
# a baseline per-step regret - confirm both against the experiment setup.
REFERENCE_SLOPE = 0.39
REFERENCE_STEPS = 5528
reference_line = [REFERENCE_SLOPE * t for t in range(REFERENCE_STEPS)]

# (legend label, colour, mean, lower, upper) for every curve in the figure.
regret_series = [
    ('Lin UCB', '#539caf', *lin_ucb_regret_data),
    ('Fixed', '#FFA07A', *fixed_regret_data),
    ('Clinical Oracle', '#228B22', *clinical_regret_data),
    ('Pharmacogenetic Oracle', '#EE82EE', *pharm_regret_data),
    ('LASSO Bandit', '#FFD700', *lasso_regret_data),
    # ('Lin UCB w/ D. features', '#18CAF5', *lin_ucb_disease_regret_data),
    ('Lasso w/ D. features', '#18CAF5', *lasso_dis_regret_data),
    ('y = 0.39x', '#000000', reference_line, reference_line, reference_line),
]
# The first 100 timesteps are left out of the figure.
lineplotCI(100, regret_series,
           'timestep t', 'cumulative regret', 'Regret', True)