Skip to content

Commit

Permalink
fixed train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nischalbasuti committed Jul 20, 2018
1 parent be6efcc commit 0cc7590
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
#!/usr/bin/env python
from patched_cnn import *

shadows = open_images("./segments/shadows")
non_shadows = open_images("./segments/non_shadows", len(shadows))
images = open_images("./data/SBUTrain4KRecoveredSmall/ShadowImages", 4070)
shadow_masks = open_images("./data/SBUTrain4KRecoveredSmall/ShadowMasks", 4070, True)

x = [] # input features.
y = [] # labels

x.extend(shadows)
x.extend(non_shadows)
x.extend(images)
y.extend(shadow_masks)

y.extend([ 1 for i in range(len(shadows)) ])
y.extend([ 0 for i in range(len(non_shadows)) ])

cnn = Patched_CNN()
cnn.build_model()

batch_size = 50
epochs = 50
patience = 5
cnn.train(x, y, batch_size, epochs, patience)
cnn.save_model()
prior_cnn = Patched_CNN()
prior_cnn.build_model(channels=3)
prior_cnn.train(
x, y,
batch_size=20,
epochs=100,
patience=5,
prefix="complete_images_")

0 comments on commit 0cc7590

Please sign in to comment.