#!/usr/bin/env python
"""Train label propagation on all the given data.

Grid-searches the propagation hyper-parameters over one or more datasets
(predictions + dissimilarities + ground truth) and writes the best model
out as a propagator file.
"""
import numpy as np
import os
import tempfile
import time
import sklearn
import sklearn.grid_search
import sklearn.metrics
from joblib import Parallel, delayed

import tsh.obsolete as tsh
from utils import (read_argsfile, read_listfile, read_truthfile,
                   read_weightsfile, write_propagatorfile, clean_args, select)
from semisupervised import propagate_labels

logger = tsh.create_logger(__name__)

# Dispatch table mapping a method name to the propagation routine it wraps.
method_table = {
    'harmonic': {'function': lambda p, w, **kw: propagate_labels(p, w, method_name='harmonic', **kw)},
    'general': {'function': lambda p, w, **kw: propagate_labels(p, w, method_name='general', **kw)}
}

# Scoring functions selectable through method_args['scoring'].
scoring_table = {
    'accuracy': sklearn.metrics.accuracy_score
}
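
# Only accuracy is wired in above. Other sklearn metrics could be added the
# same way; for instance (a hypothetical entry, not part of the original table):
#
#   scoring_table['f1'] = lambda y_true, y_pred: sklearn.metrics.f1_score(
#       y_true, y_pred, average='macro')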
def train(method_name, method_args, data, n_jobs=None, output_dir=None):
    """Grid-search propagation hyper-parameters, return (args, best model)."""
    args = method_args.copy()
    hyper_params = method_args['hyper_params']
    # One grid point per combination of the listed hyper-parameter values.
    grid = sklearn.grid_search.ParameterGrid(
        dict(zip(hyper_params, [args[p] for p in hyper_params])))
    verbose = True
    if n_jobs is None:
        n_jobs = 1
    # Score every grid point on every dataset in parallel; results come back
    # grid-point-major, i.e. len(data) consecutive entries per grid point.
    cv_results = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch='2*n_jobs')(
        delayed(fit_and_score)(
            d, method_name, args, propagate_params, verbose, output_dir)
        for propagate_params in grid for d in data)
    n_grid_points = len(list(grid))
    scores_mean = np.zeros(n_grid_points)
    scores_std = np.zeros(n_grid_points)
    params = []
    for i in range(n_grid_points):
        grid_start = i * len(data)
        scores = [score[0] for score in cv_results[grid_start:grid_start + len(data)]]
        scores_mean[i] = np.mean(scores)
        scores_std[i] = np.std(scores)
        params += [cv_results[grid_start][1]]
        # All results belonging to one grid point must have been produced with
        # the same parameters.
        for j in range(grid_start, grid_start + len(data)):
            assert cv_results[j][1] == cv_results[grid_start][1]
    best_n = np.argmax(scores_mean)
    # XXX: take the model with the smallest variance amongst the best
    args['truth'] = data[0]['meta']['truth']
    args[args['truth'] + '_labels'] = data[0]['meta'][args['truth'] + '_labels']
    args['cv_results'] = cv_results
    args['mean_score'] = scores_mean[best_n]
    args['std_score'] = scores_std[best_n]
    model = {'method_name': method_name}
    logger.info('Best model: %s', str(params[best_n]))
    model.update(params[best_n])
    return args, model
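
# A sketch of the method_args dict that drives the grid search above; the real
# values come from the -a/--args file, and 'bandwidth' is the only
# hyper-parameter this script itself references (in fit_and_score):
#
#   method_args = {
#       'scoring': 'accuracy',
#       'hyper_params': ['bandwidth'],
#       'bandwidth': [0.1, 1.0, 10.0],
#   }
#
# ParameterGrid then enumerates {'bandwidth': 0.1}, {'bandwidth': 1.0}, ...,
# and each point is scored on every dataset in `data`.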
def get_weights(dissim, distance_factor):
    """Turn a dissimilarity matrix into edge weights via a Gaussian-style
    kernel, zeroing the diagonal so no sample is its own neighbour."""
    weights = np.exp(-dissim * distance_factor)
    weights[np.diag_indices_from(weights)] = 0.
    return weights
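
# Sanity check of the kernel: a symmetric 2x2 dissimilarity matrix with
# distance_factor=1.0 gives exp(-1) ~ 0.368 off the diagonal and zeros on it:
#
#   >>> get_weights(np.array([[0., 1.], [1., 0.]]), 1.0)
#   array([[ 0.        ,  0.36787944],
#          [ 0.36787944,  0.        ]])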
def fit_and_score(data, method_name, method_args, propagate_params, verbose, output_dir):
    """Propagate labels on one dataset with one hyper-parameter setting and
    score the result against the ground truth."""
    args = method_args.copy()
    args.update(data['meta'])
    args.update(propagate_params)
    label_name = args['truth'] + '_labels'
    labels = args[label_name]
    weights = get_weights(data['dissim'], propagate_params['bandwidth'])
    propagate_fn = method_table[method_name]['function']
    propagated = propagate_fn(data['pred'], weights, labels=labels, output_dir=output_dir, **args)
    # Keep only the propagated predictions for which ground truth exists.
    propagated = select(propagated, 'id', data['truth_ids'])
    assert len(propagated['pred']) == len(data['target'])
    # sklearn metrics take (y_true, y_pred) in that order.
    score = scoring_table[method_args['scoring']](data['target'], propagated['pred'])
    logger.info('%s: %d propagated samples with ground truth yielded: %.3f',
                str(propagate_params), len(propagated['pred']), score)
    # Confusion matrix rows are true labels, columns predicted labels, both in
    # sorted label order.
    cm = sklearn.metrics.confusion_matrix(data['target'], propagated['pred'], labels=sorted(labels.keys()))
    return score, propagate_params, cm
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Train label propagation on all the given data.')
    parser.add_argument('-c', '--config', dest='config', required=False, action='store', default=None, help='Path to the config file.')
    parser.add_argument('-m', '--method', dest='method', required=True, action='store', choices=method_table.keys(), default=None, help='Method name.')
    parser.add_argument('-a', '--args', dest='args', required=False, action='store', default=None, help='Method arguments file.')
    parser.add_argument('-j', '--jobs', dest='jobs', required=False, action='store', default=None, type=int, help='Number of parallel processes.')
    parser.add_argument('-d', '--dissimilarities', dest='dissim', nargs='*', required=True, action='store', default=None, help='Dissimilarities file(s).')
    parser.add_argument('-p', '--predictions', dest='predictions', nargs='*', required=True, action='store', default=None, help='Predictions file(s).')
    parser.add_argument('-t', '--truth', dest='truth', nargs='*', required=True, action='store', default=None, help='Truth file(s).')
    parser.add_argument('--random-seed', dest='seed', required=False, action='store', type=int, default=-1, help='Random seed; by default derived from the current time.')
    parser.add_argument('-o', '--output', dest='output', required=False, action='store', default=None, help='Output directory.')
    opts = parser.parse_args()
    if opts.output is None:
        outdir = tempfile.mkdtemp(dir=os.curdir, prefix='out')
    else:
        outdir = opts.output
        if not os.path.exists(outdir):
            tsh.makedirs(outdir)
    config = tsh.read_config(opts, __file__)
    method_args = {}
    if opts.args is not None:
        method_args.update(read_argsfile(opts.args))
    if opts.seed == -1:
        seed = int(time.time() * 1024 * 1024)
    else:
        seed = opts.seed
    np.random.seed(seed)
    # Assemble one dataset record per (predictions, dissimilarities, truth)
    # triple; the three file lists are matched positionally.
    data = []
    for prediction_name, dissim_name, truth_name in zip(opts.predictions, opts.dissim, opts.truth):
        meta = {}
        truth_meta, truth_ids, target = read_truthfile(truth_name)
        meta.update(truth_meta)
        m, predictions = read_listfile(prediction_name)
        meta.update(m)
        # At least some predictions must have ground truth to score against.
        assert np.in1d(predictions['id'], np.array(truth_ids)).sum() > 0
        m, dissim_ids, dissim = read_weightsfile(dissim_name)
        # Predictions and dissimilarities must describe the same samples in
        # the same order.
        assert (predictions['id'] == np.array(dissim_ids)).all()
        meta.update(m)
        data += [{
            'meta': meta,
            'pred': predictions,
            'dissim': dissim,
            'truth_ids': truth_ids,
            'target': target
        }]
    args, model = train(opts.method, method_args, data, n_jobs=opts.jobs, output_dir=outdir)
    args['random_generator_seed'] = seed
    clean_args(args)
    write_propagatorfile(os.path.join(outdir, 'propagator.dat'), {
        'propagator': model,
        'meta': args,
        'data': data
    })
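
# Example invocation (hypothetical file names; -d, -p and -t each accept
# several files, matched positionally):
#
#   ./propagate_train.py -m harmonic -a method.args \
#       -d set1.dissim.csv -p set1.predictions.csv -t set1.truth.csv \
#       -o out/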