-
Notifications
You must be signed in to change notification settings - Fork 1
/
target.py
81 lines (62 loc) · 2.16 KB
/
target.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import abc
import lightgbm as lgb
import numpy as np
import torch
from sorel_nets import PENetwork
class AbstractTarget(abc.ABC):
    """Common base for callable target models.

    Stores where the model comes from and the decision threshold; concrete
    subclasses supply the actual querying logic in ``__call__``.
    """

    def __init__(self, model_path, thresh):
        """Record the model location and the decision threshold.

        Args:
            model_path: Location of the model, kept as ``self.model_endpoint``.
            thresh: Decision threshold, kept as ``self.model_threshold``.
        """
        self.model_endpoint, self.model_threshold = model_path, thresh

    @abc.abstractmethod
    def __call__(self, X):
        """Query the target with ``X``; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, when reached through ``super()``.
        """
        raise NotImplementedError
class LGBTarget(AbstractTarget):
    """Class for Ember and Sorel-20M LightGBM models.

    The booster is loaded once at construction; calling the instance turns
    the booster's scores into hard 0/1 labels using the stored threshold.
    """

    def __init__(self, model_path, thresh, name):
        """Load the LightGBM booster from ``model_path``.

        Args:
            model_path: Path of the saved LightGBM model file.
            thresh: Decision threshold; scores strictly above it map to 1.
            name: Identifier stored on the instance as ``self.name``.
        """
        super().__init__(model_path, thresh)
        self.name = name
        self.model = lgb.Booster(model_file=self.model_endpoint)

    def __call__(self, X):
        """Return one 0/1 label per score predicted for ``X``.

        Args:
            X: Feature matrix passed unchanged to ``Booster.predict``.

        Returns:
            Integer NumPy array with 1 where the score is strictly greater
            than the threshold and 0 elsewhere.
        """
        scores = np.asarray(self.model.predict(X))
        # Vectorised thresholding: same labels as a per-score loop, but it
        # also keeps an integer dtype when there are no scores at all.
        return (scores > self.model_threshold).astype(int)
class TorchTarget(AbstractTarget):
    """Class for Sorel-20M FCNN models.

    Wraps a ``PENetwork`` whose weights are loaded from ``model_path`` and
    exposes it as a callable returning hard 0/1 labels.
    """

    def __init__(self, model_path, thresh, name):
        """Build the network and load its weights.

        Args:
            model_path: Path of the saved ``state_dict`` checkpoint.
            thresh: Decision threshold; scores strictly above it map to 1.
            name: Identifier stored on the instance as ``self.name``.
        """
        super().__init__(model_path, thresh)
        self.name = name
        self.model = PENetwork(use_malware=True, use_counts=False, use_tags=True,
                               n_tags=11, feature_dimension=2381)
        # Inference below is CPU-only (inputs come from NumPy and outputs go
        # back through ``.numpy()``), so map the checkpoint onto the CPU; this
        # lets a GPU-saved checkpoint load on a CPU-only host.
        self.model.load_state_dict(
            torch.load(self.model_endpoint, map_location="cpu"))
        # Set model to inference mode
        self.model.eval()

    def _features_postproc_func(self, x):
        """Apply a signed log transform to a copy of ``x``.

        From sorel-20m code. Negative entries become ``-log(1 - v)``,
        positive entries become ``log(1 + v)``, and zeros are left as-is.
        The input array is not modified.
        """
        x1 = np.copy(x)
        lz = x1 < 0
        gz = x1 > 0
        x1[lz] = - np.log(1 - x1[lz])
        x1[gz] = np.log(1 + x1[gz])
        return x1

    def __call__(self, X):
        """Return one 0/1 label per row of ``X``.

        Args:
            X: NumPy feature array; it is post-processed with
                ``_features_postproc_func`` before being fed to the network.

        Returns:
            Integer NumPy array with 1 where the network's ``"malware"``
            output is strictly greater than the threshold and 0 elsewhere.
        """
        features = torch.from_numpy(self._features_postproc_func(X))
        # Align the input dtype with the network weights so that e.g. float64
        # features do not trigger a dtype mismatch in the forward pass.
        features = features.to(next(self.model.parameters()).dtype)
        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            predictions = self.model(features)
        scores = predictions["malware"].detach().numpy().ravel()
        return (scores > self.model_threshold).astype(int)
class FileBasedTarget(AbstractTarget):
    """Class for targets for which we have offline labels, such as AVs.

    No model is queried: labels are looked up by index in a precomputed
    collection, and every label handed out is removed from that collection.
    """

    def __init__(self, model_path, name, labels):
        """Store the precomputed labels.

        Args:
            model_path: Kept as ``self.model_endpoint`` for parity with the
                other targets; it is not read by this class.
            name: Identifier stored on the instance as ``self.name``.
            labels: Indexable collection of precomputed labels.
        """
        # Initialise the base attributes too; there is no threshold because
        # stored labels are returned as-is.
        super().__init__(model_path, None)
        self.name = name
        self.labels = labels

    def __call__(self, idx):
        """Return the stored labels at ``idx`` and drop them from the pool.

        Args:
            idx: Index (or indices) into the labels that are currently left.

        Returns:
            ``self.labels[idx]`` as it was before the removal.
        """
        scores = self.labels[idx]
        # Consume the returned entries: the remaining labels shift down, so
        # indices in later calls refer to the shrunken array.
        self.labels = np.delete(self.labels, idx, axis=0)
        return scores