diff --git a/train.py b/train.py index b25d303..6b9459f 100644 --- a/train.py +++ b/train.py @@ -1,23 +1,20 @@ #!/usr/bin/env python from patched_cnn import * -shadows = open_images("./segments/shadows") -non_shadows = open_images("./segments/non_shadows", len(shadows)) +images = open_images("./data/SBUTrain4KRecoveredSmall/ShadowImages", 4070) +shadow_masks = open_images("./data/SBUTrain4KRecoveredSmall/ShadowMasks", 4070, True) x = [] # input features. y = [] # labels -x.extend(shadows) -x.extend(non_shadows) +x.extend(images) +y.extend(shadow_masks) -y.extend([ 1 for i in range(len(shadows)) ]) -y.extend([ 0 for i in range(len(non_shadows)) ]) - -cnn = Patched_CNN() -cnn.build_model() - -batch_size = 50 -epochs = 50 -patience = 5 -cnn.train(x, y, batch_size, epochs, patience) -cnn.save_model() +prior_cnn = Patched_CNN() +prior_cnn.build_model(channels=3) +prior_cnn.train( + x, y, + batch_size=20, + epochs=100, + patience=5, + prefix="complete_images_")