Skip to content

Commit

Permalink
debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroAska committed Apr 7, 2018
1 parent 35e2c95 commit 213277f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 44 deletions.
59 changes: 41 additions & 18 deletions data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Process:
def __init__(self, location, height = 224, width = 224, depth = 3):
self.img2labels = dict()
self.idx2img = dict()

self.num_crops=1
self.location = location
self.height = height
self.width = width
Expand Down Expand Up @@ -80,41 +80,61 @@ def generateData(self, num):
ctr += 1

# array to store which sample has been used
self.numSamples = 128*num
self.numSamples = self.num_crops*num
self.remSamples = np.ones((self.numSamples), dtype = bool)

# generate samples
self.process(num, indices)

def centeredCrop(self, img, output_side_length):
    """Slice a square window of side ``output_side_length`` out of ``img``.

    The window offsets are derived from the dimensions the image *would*
    have if its shorter side were scaled to ``output_side_length`` (aspect
    ratio kept).  ``img`` itself is not resized here; the offsets are
    applied directly to the array that was passed in.

    Args:
        img: 3-D array indexed as (row, column, channel).  A non-3-D
            ``shape`` raises ``ValueError`` on unpacking.
        output_side_length: Side length, in pixels, of the square window.

    Returns:
        The slice ``img[h0:h0 + side, w0:w0 + side]`` (all channels kept).
        For a NumPy array this is a view, not a copy.
    """
    # Third component is unpacked only to enforce a 3-D input.
    height, width, _ = img.shape

    new_height = output_side_length
    new_width = output_side_length
    # Floor division keeps every quantity an int.  True division ("/")
    # produces floats on Python 3, and float slice bounds raise TypeError
    # when indexing a NumPy array.  For ints, "//" matches Python 2 "/".
    if height > width:
        new_height = output_side_length * height // width
    else:
        new_width = output_side_length * width // height

    # NOTE(review): offsets come from the scaled sizes above, not from
    # img's real size, so the window is centered only when img's shorter
    # side already equals output_side_length -- confirm this is intended.
    height_offset = (new_height - output_side_length) // 2
    width_offset = (new_width - output_side_length) // 2

    cropped_img = img[height_offset:height_offset + output_side_length,
                      width_offset:width_offset + output_side_length]
    return cropped_img


def process(self, num, indices):
print('Generating Samples .... ')

self.sampleImages = np.zeros((num*128, self.height, self.width, self.depth), dtype=np.uint8)
self.sampleImages = np.zeros((num*self.num_crops , self.height, self.width, self.depth), dtype=np.uint8)

idx = 0
imgs = np.zeros((num, 256, 455, 3))
names = []
for i in range(num):
# read image
img = cv2.resize(cv2.imread(indices[i], 1).astype(float), (455,256), interpolation = cv2.INTER_CUBIC)
imgs[i] = img
name = self.getName(indices[i])

temp_mean = np.mean(img, axis=0)
temp_mean = np.mean(temp_mean, axis=0)
temp_std = np.zeros(self.depth)

for i in range(self.depth):
img[:,:,i] -= temp_mean[i]
temp_std[i] = np.std(img[:,:,i])
img[:,:,i] /= temp_std[i]

names.append(name)
means = np.mean(imgs, axis=0)
imgs = imgs - means
#temp_mean = np.mean(img, axis=0)
#temp_mean = np.mean(temp_mean, axis=0)
#temp_std = np.zeros(self.depth)

#for i in range(self.depth):
# img[:,:,i] -= temp_mean[i]
# temp_std[i] = np.std(img[:,:,i])
# img[:,:,i] /= temp_std[i]

for i in range(num):
# generate 128 random indices for crop
for j in range(idx, idx+128):
for j in range(idx, idx+self.num_crops ):
x = randint(0,31)
y = randint(0,230)
self.sampleImages[j, :, :, :] = img[x:x + self.height, y:y + self.width, :].copy()
self.idx2img[j] = name
idx += 128
img = imgs[i, :,:,:]
self.sampleImages[j, :, :, :] = self.centeredCrop(img, 224)#img[x:x + self.height, y:y + self.width, :].copy()
self.idx2img[j] = names[i]
idx += self.num_crops



Expand All @@ -138,7 +158,10 @@ def fetch(self, num):

# gaussian normalization of image to have mean 0, variance 1
temp = self.sampleImages[idx, :, :, :].astype(float)
labels[ctr] = self.img2labels[self.idx2img[idx]]
try:
labels[ctr] = self.img2labels[self.idx2img[idx]]
except KeyError:
pdb.set_trace()
ctr += 1

return [flag, samples, labels]
Expand Down
56 changes: 30 additions & 26 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@ def delete_network_backups(filename_prefix):

class trainer():

def __init__(self,path_to_weight, path_to_data, resume_training=False):
def __init__(self,path_to_weight, path_to_data, beta, resume_training=False):
self.network_input_size = 224
self.image_inputs = tf.placeholder(tf.float32, [None, self.network_input_size, self.network_input_size, 3])
self.label_inputs = tf.placeholder(tf.float32, [None, 7]) # [ X Y Z W P Q R]
self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
if resume_training:
self.restore_network(path_to_weight)
else:
self.image_inputs = tf.placeholder(tf.float32, [None, self.network_input_size, self.network_input_size, 3])
self.label_inputs = tf.placeholder(tf.float32, [None, 7]) # [ X Y Z W P Q R]

self.network = vgg.VGG16({'data': self.image_inputs})
self.regen_regression_network()
self.build_loss()
self.build_loss(beta)
self.saver = tf.train.Saver()

self.merged_summary = tf.summary.merge_all()
Expand All @@ -58,8 +59,8 @@ def __init__(self,path_to_weight, path_to_data, resume_training=False):
print("Model initialized")

def init_data_handler(self,path_to_data):
#self.data_handler = data_handler.Process(path_to_data)
self.data_handler = gen_data.get_data()
self.data_handler = data_handler.Process(path_to_data)
#self.data_handler = gen_data.get_data()

def load_weight(self,path_to_weight):
self.network.load(path_to_weight, self.sess)
Expand Down Expand Up @@ -95,13 +96,15 @@ def regen_regression_network(self):
self.network.variable_summaries(self.regression_out, "regression_output_")

def restore_network(self, path_to_weight):

self.saver = tf.train.import_meta_graph(path_to_weight + ".meta" )
graph = tf.get_default_graph()
self.regression_out = graph.get_operation_by_name("regression_output")
self.regression_out = tf.get_default_graph().get_tensor_by_name('fc9/fc9:0')
self.loss = graph.get_operation_by_name("final_loss")
self.train_op = tf.get_default_graph().get_operation_by_name("Adam_minimizer")
self.saver.restore(self.sess, tf.train.latest_checkpoint('./'))
self.saver.restore(self.sess, path_to_weight)#tf.train.latest_checkpoint('./'))
self.image_inputs = tf.get_default_graph().get_tensor_by_name('Placeholder:0')
self.label_inputs = tf.get_default_graph().get_tensor_by_name('Placeholder_1:0')

print("Model restored.")

def build_loss(self, beta=100):
Expand All @@ -127,7 +130,7 @@ def build_loss(self, beta=100):
self.network.variable_summaries(self.rotation_loss, "rotation_loss_")
self.network.variable_summaries(self.loss, "final_weighted_loss_")

def test(self, img, num_random_crops=10):
def test(self, img, num_random_crops=20):
if img.shape[2] != 3:
print ("We only accept 3-dimensional rgb images")
if img.shape[0] != self.network_input_size or img.shape[1] != self.network_input_size:
Expand All @@ -151,17 +154,18 @@ def test(self, img, num_random_crops=10):
def train(self, batch_size, epochs):

total_loss = 0
num_crops_per_img = 128
total_batch = 100#int(self.data_handler.numimages() * num_crops_per_img / batch_size)
#print("[trainer] Start Training, size of dataset is " + str(334/32)) #+str(self.data_handler.numimages() * num_crops_per_img ))
pdb.set_trace()

total_batch = int(self.data_handler.numimages() * self.data_handler.num_crops / batch_size) #100
#print("[trainer] Start Training, size of dataset is " +str(self.data_handler.numimages() * self.data_handler.num_crops ))
#pdb.set_trace()
for epoch in range(epochs):
#self.data_handler.reset()
#self.data_handler.generateData(500)
data_gen = gen_data.gen_data_batch(self.data_handler )
self.data_handler.reset()
self.data_handler.generateData(500)
#data_gen = gen_data.gen_data_batch(self.data_handler )
for i in range(total_batch):
'''

data_runout_flag, one_batch_image , one_batch_label = self.data_handler.fetch(batch_size)
'''
if data_runout_flag == False:
if self.data_handler.remimages() > 0:
self.data_handler.generateData(500)
Expand All @@ -170,15 +174,15 @@ def train(self, batch_size, epochs):
self.data_handler.generateData(500)
data_runout_flag, one_batch_image , one_batch_label = self.data_handler.fetch(batch_size)
'''
one_batch_image, np_poses_x, np_poses_q = next(data_gen)
one_batch_label = np.hstack((np_poses_x, np_poses_q))
#one_batch_image, np_poses_x, np_poses_q = next(data_gen)
#one_batch_label = np.hstack((np_poses_x, np_poses_q))
feeds ={self.image_inputs: one_batch_image, self.label_inputs: one_batch_label }
summary, loss, gradients = self.sess.run([self.merged_summary, self.loss, self.compute_gradients ], feed_dict=feeds)
self.sess.run([self.train_op], feed_dict=feeds )
print("[Epoch "+str(epoch)+" trainer] Train one batch of size "+str(batch_size)+", loss is "+str(loss))
total_loss += loss
self.train_writer.add_summary(summary, epoch * total_batch + i)

avg_loss = (total_loss)/total_batch
self.saver.save(self.sess, "./model_epoch_"+str(epoch)+".ckpt")
if epoch > 0: delete_network_backups("./model_epoch_"+str(epoch-1)+".ckpt" )
Expand All @@ -189,12 +193,12 @@ def train(self, batch_size, epochs):

if __name__ == "__main__":
argv = sys.argv
tf.logging.set_verbosity(tf.logging.ERROR)
if len(sys.argv) < 4:
argv = ['', '', '', '']
if len(sys.argv) < 5:
argv = ['' for _ in range(5)]
argv[1] = './vgg.data'
argv[2] = './ShopFacade/'
argv[3] = bool(int(False))
argv[3] = 100
argv[4] = bool(int(False))
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
train_thread = trainer(argv[1], argv[2], bool(int(argv[3])))
train_thread = trainer(argv[1], argv[2], int(argv[3]), bool(int(argv[4])))
train_thread.train(32, 600)

0 comments on commit 213277f

Please sign in to comment.