-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
25 lines (24 loc) · 1.07 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from binary_model import BinaryModel
from multiclass_model import MulticlassCNNModel, MulticlassNNModel, MulticlassCNNDropoutModel, MulticlassCNNReluActivationDropoutModel, CNN6LayerModel, CNN4LayerModel, VGGLikeModel
"""
Get model by name.
"""
def get_model(model, run_name):
    """Return a new model instance selected by name.

    Args:
        model: Name of the model to build. Recognized names are
            ``"binary"``, ``"nn"``, ``"cnn"``, ``"dropout"``, ``"6layer"``,
            ``"4layer"``, ``"vgglike"`` and ``"reludropout"``.
        run_name: Forwarded unchanged to the model constructor as its
            ``run_name`` keyword argument.

    Returns:
        A ``BinaryModel`` for ``"binary"``; otherwise an instance of the
        matching multiclass model constructed with ``outputs=4``. Any
        unrecognized name falls back to
        ``MulticlassCNNReluActivationDropoutModel``.
    """
    # The binary model is the only one constructed without ``outputs``.
    if model == "binary":
        return BinaryModel(run_name=run_name)

    multiclass_models = {
        "nn": MulticlassNNModel,
        "cnn": MulticlassCNNModel,
        "dropout": MulticlassCNNDropoutModel,
        "6layer": CNN6LayerModel,
        "4layer": CNN4LayerModel,
        "vgglike": VGGLikeModel,
        "reludropout": MulticlassCNNReluActivationDropoutModel,
    }
    # Unknown names deliberately fall back to the ReLU/dropout CNN rather
    # than raising, preserving the original default behavior.
    model_class = multiclass_models.get(
        model, MulticlassCNNReluActivationDropoutModel
    )
    # Every multiclass model takes the same ``outputs`` keyword. The
    # "reludropout" branch used to pass ``output=4`` (a typo), which did not
    # match the keyword used for the same class in the fallback branch.
    # NOTE(review): 4 is presumably the number of target classes - confirm.
    return model_class(outputs=4, run_name=run_name)