-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
52 lines (39 loc) · 1.23 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import segmentation_models as seg_models
from keras.losses import categorical_crossentropy
from keras.optimizers import Adam
from keras_contrib.losses import jaccard_distance
from utils import tversky_index
# Registry of supported architectures: name -> segmentation_models constructor.
# Looked up by the `architecture` keyword when building a model.
MODELS = dict(
    unet=seg_models.Unet,
    fpn=seg_models.FPN,
    pspnet=seg_models.PSPNet,
    linknet=seg_models.Linknet,
)
def tversky_loss(t, p):
    """Return ``1 - tversky_index(t, p)``.

    Both arguments are forwarded unchanged to ``utils.tversky_index``
    (presumably ground-truth and predicted tensors -- confirm in utils).
    """
    index = tversky_index(t, p)
    return 1 - index
def combined_loss(t, p):
    """Return categorical cross-entropy plus the Tversky loss for ``(t, p)``.

    Used as the training loss in ``build``.
    """
    cross_entropy = categorical_crossentropy(t, p)
    tversky = tversky_loss(t, p)
    return cross_entropy + tversky
def _make_adam(learning_rate):
    """Create an Adam optimizer in a way that works across Keras versions.

    Newer Keras names the argument ``learning_rate`` (``lr`` is deprecated
    and later removed); older Keras only accepts ``lr`` and rejects unknown
    keyword arguments with ``TypeError``, so fall back to ``lr`` in that case.
    """
    try:
        return Adam(learning_rate=learning_rate)
    except TypeError:
        return Adam(lr=learning_rate)


def build(**kwargs):
    """Build, optionally load weights into, compile and return a model.

    Keyword arguments:
        architecture: key of ``MODELS`` ('unet', 'fpn', 'pspnet',
            'linknet'). Required.
        backbone: forwarded as ``backbone_name``. Required.
        input_shape: forwarded as ``input_shape``. Required.
        num_classes: forwarded as ``classes``. Required.
        freeze_encoder: forwarded as ``encoder_freeze`` (default True).
        learning_rate: Adam learning rate (default 1e-5).
        weights: passed to ``model.load_weights`` when truthy
            (presumably a weights file path); default None skips loading.
    Any other keys are ignored, so a larger config dict can be splatted in.

    Returns:
        The compiled model, using softmax outputs, ``combined_loss`` and the
        ``categorical_accuracy`` metric. Its summary is printed as a side
        effect.

    Raises:
        KeyError: if a required key is missing, or ``architecture`` is not
            one of the names in ``MODELS``.
    """
    architecture = kwargs['architecture']
    backbone = kwargs['backbone']
    input_shape = kwargs['input_shape']
    num_classes = kwargs['num_classes']
    freeze_encoder = kwargs.get('freeze_encoder', True)
    learning_rate = kwargs.get('learning_rate', 1e-5)
    weights = kwargs.get('weights', None)

    # Keep KeyError for compatibility, but say which names are valid.
    model_cls = MODELS.get(architecture)
    if model_cls is None:
        raise KeyError(
            'Unknown architecture {!r}; expected one of: {}'.format(
                architecture, ', '.join(sorted(MODELS))))

    model = model_cls(
        backbone_name=backbone,
        input_shape=input_shape,
        classes=num_classes,
        # softmax pairs with the categorical cross-entropy in combined_loss.
        activation='softmax',
        encoder_freeze=freeze_encoder,
    )
    if weights:
        model.load_weights(weights)
    model.compile(
        optimizer=_make_adam(learning_rate),
        loss=combined_loss,
        metrics=['categorical_accuracy']
    )
    model.summary()
    return model