-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstrategy.py
54 lines (39 loc) · 1.63 KB
/
strategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Strategy:
    """Base class that bundles a data pool, a validation set and a model.

    Almost every method is a thin pass-through to the wrapped ``model``
    object; the class itself only stores its constructor arguments, tracks
    ``idxs_lb`` and exposes a ``query`` hook that does nothing here.

    Attributes:
        X, Y: Pool inputs and targets, stored as given.
        X_val, Y_val: Validation inputs and targets, stored as given.
        idxs_lb: Passed to ``model.train`` alongside the pool; presumably
            marks which pool entries are labelled -- confirm with callers.
        device: Stored as given; not read by this class.
        model: Object that implements the training / prediction methods
            this class forwards to.
        args: Stored as given; not read by this class.
        n_pool: ``len(Y)`` at construction time.
        writer: Stored as given; not read by this class.
        query_count: Initialised to 0; never modified by this class.
    """

    def __init__(self, X, Y, idxs_lb, X_val, Y_val, model, args, device, writer):
        """Store the arguments and derive ``n_pool`` from ``len(Y)``."""
        self.X = X
        self.Y = Y
        self.X_val = X_val
        self.Y_val = Y_val
        self.idxs_lb = idxs_lb
        self.device = device
        self.model = model
        self.args = args
        self.n_pool = len(Y)
        self.writer = writer
        self.query_count = 0

    def query(self, n):
        """Do nothing and return ``None``.

        Presumably a hook that concrete strategies override to select ``n``
        pool items -- confirm against the subclasses.
        """
        return None

    def update(self, idxs_lb):
        """Replace the stored ``idxs_lb`` with the given value."""
        self.idxs_lb = idxs_lb

    def train(self, name):
        """Call ``model.train`` with ``name``, the pool, ``idxs_lb`` and the validation set.

        The model's return value is not propagated; this method returns
        ``None``.
        """
        self.model.train(name, self.X, self.Y, self.idxs_lb, self.X_val, self.Y_val)

    def predict(self, X, Y):
        """Return ``model.predict(X, Y)``."""
        return self.model.predict(X, Y)

    def predict_prob(self, X, Y):
        """Return ``model.predict_prob(X, Y)``."""
        return self.model.predict_prob(X, Y)

    def predict_prob_embed(self, X, Y, eval=True):
        """Return ``model.predict_prob_embed(X, Y, eval)``.

        ``eval`` shadows the builtin of the same name; the name is kept
        because it is part of the caller-visible signature. It is forwarded
        positionally as the third argument.
        """
        return self.model.predict_prob_embed(X, Y, eval)

    def predict_all_representations(self, X, Y):
        """Return ``model.predict_all_representations(X, Y)``."""
        return self.model.predict_all_representations(X, Y)

    def predict_embedding_prob(self, X_embedding):
        """Return ``model.predict_embedding_prob(X_embedding)``."""
        return self.model.predict_embedding_prob(X_embedding)

    def predict_prob_dropout(self, X, Y, n_drop):
        """Return ``model.predict_prob_dropout(X, Y, n_drop)``."""
        return self.model.predict_prob_dropout(X, Y, n_drop)

    def predict_prob_dropout_split(self, X, Y, n_drop):
        """Return ``model.predict_prob_dropout_split(X, Y, n_drop)``."""
        return self.model.predict_prob_dropout_split(X, Y, n_drop)

    def predict_prob_embed_dropout_split(self, X, Y, n_drop):
        """Return ``model.predict_prob_embed_dropout_split(X, Y, n_drop)``."""
        return self.model.predict_prob_embed_dropout_split(X, Y, n_drop)

    def get_embedding(self, X, Y):
        """Return ``model.get_embedding(X, Y)``."""
        return self.model.get_embedding(X, Y)

    def get_grad_embedding(self, X, Y, is_embedding=False):
        """Return ``model.get_grad_embedding(X, Y, is_embedding)``.

        ``is_embedding`` is forwarded positionally as the third argument.
        """
        return self.model.get_grad_embedding(X, Y, is_embedding)