-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathaimodel_ss.py
180 lines (152 loc) · 5.19 KB
/
aimodel_ss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#
# Copyright 2024 Ocean Protocol Foundation
# SPDX-License-Identifier: Apache-2.0
#
from typing import Optional, Tuple
from enforce_typing import enforce_types
from pdr_backend.util.strutil import StrMixin
# Classifier names accepted for the "approach" setting.
CLASSIF_APPROACH_OPTIONS = [
    "ClassifLinearLasso",
    "ClassifLinearLasso_Balanced",
    "ClassifLinearRidge",
    "ClassifLinearRidge_Balanced",
    "ClassifLinearElasticNet",
    "ClassifLinearElasticNet_Balanced",
    "ClassifLinearSVM",
    "ClassifGaussianProcess",
    "ClassifXgboost",
    "ClassifConstant",
]

# Regressor names accepted for the "approach" setting.
REGR_APPROACH_OPTIONS = [
    "RegrLinearLS",
    "RegrLinearLasso",
    "RegrLinearRidge",
    "RegrLinearElasticNet",
    "RegrGaussianProcess",
    "RegrXgboost",
    "RegrConstant",
]

# Every accepted "approach" value: classifiers first, then regressors.
APPROACH_OPTIONS = [*CLASSIF_APPROACH_OPTIONS, *REGR_APPROACH_OPTIONS]

# Accepted values for "weight_recent". Note that "None" is the string,
# not the Python None object.
WEIGHT_RECENT_OPTIONS = ["10x_5x", "10000x", "None"]

# Accepted values for "balance_classes".
BALANCE_CLASSES_OPTIONS = ["SMOTE", "RandomOverSampler", "None"]

# Accepted values for "calibrate_probs".
CALIBRATE_PROBS_OPTIONS = [
    "CalibratedClassifierCV_Sigmoid",
    "CalibratedClassifierCV_Isotonic",
    "None",
]

# Accepted values for "calibrate_regr".
CALIBRATE_REGR_OPTIONS = ["CurrentYval", "None"]
class AimodelSS(StrMixin):
    """Accessors and validation for the "aimodel_ss" settings dict.

    The dict is stored as-is in `self.d`; each property reads from it on
    every access. The constructor rejects values that fall outside the
    module-level *_OPTIONS lists.
    """

    # Presumably tells StrMixin which attributes to include when
    # stringifying -- confirm against StrMixin.
    __STR_OBJDIR__ = ["d"]

    @enforce_types
    def __init__(self, d: dict):
        """d -- yaml_dict["aimodel_ss"]"""
        self.d = d

        # Check the settings one at a time, in this fixed order. Each value
        # is read only when its turn comes, so the first problem found is
        # the one that gets raised.
        allowed_by_setting = (
            ("approach", APPROACH_OPTIONS),
            ("weight_recent", WEIGHT_RECENT_OPTIONS),
            ("balance_classes", BALANCE_CLASSES_OPTIONS),
            ("calibrate_probs", CALIBRATE_PROBS_OPTIONS),
            ("calibrate_regr", CALIBRATE_REGR_OPTIONS),
        )
        for setting, allowed in allowed_by_setting:
            value = getattr(self, setting)
            if value not in allowed:
                raise ValueError(value)

        self.validate_train_every_n_epochs(self.train_every_n_epochs)

    # --------------------------------
    # validators -- one per setter

    def validate_train_every_n_epochs(self, n: int):
        """Raise ValueError(n) when n <= 0; otherwise do nothing."""
        if n <= 0:
            raise ValueError(n)

    # --------------------------------
    # properties read straight from the yaml dict

    @property
    def approach(self) -> str:
        """Model approach name, eg 'ClassifLinearRidge'"""
        name = self.d["approach"]
        return name

    @property
    def weight_recent(self) -> str:
        """Recent-sample weighting scheme, eg '10x_5x'"""
        scheme = self.d["weight_recent"]
        return scheme

    @property
    def balance_classes(self) -> str:
        """Class-balancing scheme, eg 'SMOTE'"""
        scheme = self.d["balance_classes"]
        return scheme

    @property
    def train_every_n_epochs(self) -> int:
        """eg 5, ie train every 5 epochs. The stored value is cast to int."""
        raw = self.d["train_every_n_epochs"]
        return int(raw)

    @property
    def calibrate_probs(self) -> str:
        """Probability-calibration scheme, eg 'CalibratedClassifierCV_Sigmoid'"""
        scheme = self.d["calibrate_probs"]
        return scheme

    @property
    def seed(self) -> Optional[int]:
        """Value stored under 'seed', or None when that key is absent."""
        return self.d.get("seed")

    @property
    def calc_imps(self) -> bool:
        """Whether to calc feature importances; True when the key is absent."""
        return self.d.get("calc_imps", True)

    def calibrate_probs_skmethod(self, N: int) -> str:
        """
        @description
          Give the 'method' argument to pass to sklearn
          CalibratedClassifierCV().

        @arguments
          N -- number of samples

        @return
          "sigmoid" or "isotonic"

        @notes
          When N < 200, the result is "sigmoid" no matter what
          calibrate_probs holds. Otherwise, a calibrate_probs value that
          maps to neither method raises ValueError.
        """
        # Small-sample shortcut: decided before calibrate_probs is read.
        if N < 200:
            return "sigmoid"

        choice = self.calibrate_probs
        skmethod_by_choice = (
            ("CalibratedClassifierCV_Sigmoid", "sigmoid"),
            ("CalibratedClassifierCV_Isotonic", "isotonic"),
        )
        for option, skmethod in skmethod_by_choice:
            if choice == option:
                return skmethod
        raise ValueError(choice)

    @property
    def calibrate_regr(self) -> str:
        """Regression-calibration scheme, eg 'CurrentYval'"""
        scheme = self.d["calibrate_regr"]
        return scheme

    # --------------------------------
    # properties derived from the yaml values

    @property
    def do_regr(self) -> bool:
        """True when the approach name starts with 'Regr'."""
        return self.approach.startswith("Regr")

    @property
    def weight_recent_n(self) -> Tuple[int, int]:
        """@return -- (n_repeat1, n_repeat2)"""
        scheme = self.weight_recent
        repeats_by_scheme = (
            ("None", (0, 0)),
            ("10x_5x", (10, 5)),
            ("10000x", (10000, 0)),
        )
        for option, repeats in repeats_by_scheme:
            if scheme == option:
                return repeats
        raise ValueError(scheme)

    # --------------------------------
    # setters (only add as needed)

    def set_train_every_n_epochs(self, n: int):
        """Validate n, then store it under 'train_every_n_epochs'."""
        self.validate_train_every_n_epochs(n)
        self.d["train_every_n_epochs"] = n
# =========================================================================
# utilities for testing
@enforce_types
def aimodel_ss_test_dict(
    approach: Optional[str] = None,
    weight_recent: Optional[str] = None,
    balance_classes: Optional[str] = None,
    calibrate_probs: Optional[str] = None,
    calibrate_regr: Optional[str] = None,
    train_every_n_epochs: Optional[int] = None,
) -> dict:
    """
    @description
      Build a settings dict 'd' that can be passed to AimodelSS(d).

    @arguments
      Each argument overrides the matching dict entry. A string argument
      that is None or empty falls back to its default.
      train_every_n_epochs falls back to 1 only when it is None, so an
      explicit 0 is passed through unchanged.

    @return
      d -- dict with keys "approach", "weight_recent", "balance_classes",
        "calibrate_probs", "calibrate_regr", "train_every_n_epochs"
    """
    if train_every_n_epochs is None:
        train_every_n_epochs = 1

    settings = {
        "approach": approach if approach else "ClassifLinearRidge",
        "weight_recent": weight_recent if weight_recent else "10x_5x",
        "balance_classes": balance_classes if balance_classes else "SMOTE",
        "calibrate_probs": (
            calibrate_probs
            if calibrate_probs
            else "CalibratedClassifierCV_Sigmoid"
        ),
        "calibrate_regr": calibrate_regr if calibrate_regr else "None",
        "train_every_n_epochs": train_every_n_epochs,
    }
    return settings