diff --git a/data_handler.py b/data_handler.py
index 58ac51f..aca43ad 100644
--- a/data_handler.py
+++ b/data_handler.py
@@ -24,7 +24,7 @@ class Process:
     def __init__(self, location, height = 224, width = 224, depth = 3):
         self.img2labels = dict()
         self.idx2img = dict()
-
+        self.num_crops = 1
         self.location = location
         self.height = height
         self.width = width
@@ -80,41 +80,61 @@ def generateData(self, num):
             ctr += 1

         # array to store which sample has been used
-        self.numSamples = 128*num
+        self.numSamples = self.num_crops*num
         self.remSamples = np.ones((self.numSamples), dtype = bool)

         # generate samples
         self.process(num, indices)

+    def centeredCrop(self, img, output_side_length):
+        # integer arithmetic so the offsets are valid slice indices
+        height, width, depth = img.shape
+        new_height = output_side_length
+        new_width = output_side_length
+        if height > width:
+            new_height = output_side_length * height // width
+        else:
+            new_width = output_side_length * width // height
+        height_offset = (new_height - output_side_length) // 2
+        width_offset = (new_width - output_side_length) // 2
+        cropped_img = img[height_offset:height_offset + output_side_length,
+                          width_offset:width_offset + output_side_length]
+        return cropped_img

     def process(self, num, indices):
         print('Generating Samples .... ')
-        self.sampleImages = np.zeros((num*128, self.height, self.width, self.depth), dtype=np.uint8)
+        self.sampleImages = np.zeros((num*self.num_crops, self.height, self.width, self.depth), dtype=np.uint8)
         idx = 0
+        imgs = np.zeros((num, 256, 455, 3))
+        names = []
         for i in range(num):
             # read image
             img = cv2.resize(cv2.imread(indices[i], 1).astype(float), (455,256), interpolation = cv2.INTER_CUBIC)
+            imgs[i] = img
             name = self.getName(indices[i])
-
-            temp_mean = np.mean(img, axis=0)
-            temp_mean = np.mean(temp_mean, axis=0)
-            temp_std = np.zeros(self.depth)
-
-            for i in range(self.depth):
-                img[:,:,i] -= temp_mean[i]
-                temp_std[i] = np.std(img[:,:,i])
-                img[:,:,i] /= temp_std[i]
-
+            names.append(name)
+
+        # subtract the mean image computed over all loaded samples
+        means = np.mean(imgs, axis=0)
+        imgs = imgs - means
+
+        for i in range(num):
             # generate 128 random indices for crop
-            for j in range(idx, idx+128):
+            for j in range(idx, idx + self.num_crops):
                 x = randint(0,31)
                 y = randint(0,230)
-                self.sampleImages[j, :, :, :] = img[x:x + self.height, y:y + self.width, :].copy()
-                self.idx2img[j] = name
-            idx += 128
+                img = imgs[i, :, :, :]
+                self.sampleImages[j, :, :, :] = self.centeredCrop(img, 224)  # was: img[x:x + self.height, y:y + self.width, :].copy()
+                self.idx2img[j] = names[i]
+            idx += self.num_crops
@@ -138,7 +158,10 @@ def fetch(self, num):

             # gaussian normalization of image to have mean 0, variance 1
             temp = self.sampleImages[idx, :, :, :].astype(float)
-            labels[ctr] = self.img2labels[self.idx2img[idx]]
+            try:
+                labels[ctr] = self.img2labels[self.idx2img[idx]]
+            except KeyError:
+                pdb.set_trace()
             ctr += 1

         return [flag, samples, labels]
diff --git a/train.py b/train.py
index 48d93ad..38469ab 100644
--- a/train.py
+++ b/train.py
@@ -29,17 +29,18 @@ def delete_network_backups(filename_prefix):

 class trainer():
-    def __init__(self,path_to_weight, path_to_data, resume_training=False):
+    def __init__(self, path_to_weight, path_to_data, beta, resume_training=False):
         self.network_input_size = 224
-        self.image_inputs = tf.placeholder(tf.float32, [None, self.network_input_size, self.network_input_size, 3])
-        self.label_inputs = tf.placeholder(tf.float32, [None, 7]) # [ X Y Z W P Q R]
         self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
         if resume_training:
             self.restore_network(path_to_weight)
         else:
+            self.image_inputs = tf.placeholder(tf.float32, [None, self.network_input_size, self.network_input_size, 3])
+            self.label_inputs = tf.placeholder(tf.float32, [None, 7]) # [ X Y Z W P Q R]
+
             self.network = vgg.VGG16({'data': self.image_inputs})
             self.regen_regression_network()
-            self.build_loss()
+            self.build_loss(beta)
             self.saver = tf.train.Saver()
             self.merged_summary = tf.summary.merge_all()
@@ -58,8 +59,8 @@ def __init__(self,path_to_weight, path_to_data, resume_training=False):
         print("Model initialized")

     def init_data_handler(self,path_to_data):
-        #self.data_handler = data_handler.Process(path_to_data)
-        self.data_handler = gen_data.get_data()
+        self.data_handler = data_handler.Process(path_to_data)
+        #self.data_handler = gen_data.get_data()

     def load_weight(self,path_to_weight):
         self.network.load(path_to_weight, self.sess)
@@ -95,13 +96,15 @@ def regen_regression_network(self):
         self.network.variable_summaries(self.regression_out, "regression_output_")

     def restore_network(self, path_to_weight):
-
         self.saver = tf.train.import_meta_graph(path_to_weight + ".meta" )
         graph = tf.get_default_graph()
-        self.regression_out = graph.get_operation_by_name("regression_output")
+        self.regression_out = graph.get_tensor_by_name('fc9/fc9:0')
         self.loss = graph.get_operation_by_name("final_loss")
         self.train_op = tf.get_default_graph().get_operation_by_name("Adam_minimizer")
-        self.saver.restore(self.sess, tf.train.latest_checkpoint('./'))
+        self.saver.restore(self.sess, path_to_weight)  # was: tf.train.latest_checkpoint('./')
+        self.image_inputs = graph.get_tensor_by_name('Placeholder:0')
+        self.label_inputs = graph.get_tensor_by_name('Placeholder_1:0')
+
         print("Model restored.")

     def build_loss(self, beta=100):
@@ -127,7 +130,7 @@ def build_loss(self, beta=100):
         self.network.variable_summaries(self.rotation_loss, "rotation_loss_")
         self.network.variable_summaries(self.loss, "final_weighted_loss_")

-    def test(self, img, num_random_crops=10):
+    def test(self, img, num_random_crops=20):
         if img.shape[2] != 3:
             print ("We only accept 3-dimensional rgb images")
         if img.shape[0] != self.network_input_size or img.shape[1] != self.network_input_size:
@@ -151,17 +154,18 @@ def test(self, img, num_random_crops=10):

     def train(self, batch_size, epochs):
         total_loss = 0
-        num_crops_per_img = 128
-        total_batch = 100#int(self.data_handler.numimages() * num_crops_per_img / batch_size)
-        #print("[trainer] Start Training, size of dataset is " + str(334/32)) #+str(self.data_handler.numimages() * num_crops_per_img ))
-        pdb.set_trace()
+
+        total_batch = int(self.data_handler.numimages() * self.data_handler.num_crops / batch_size)
+        #print("[trainer] Start Training, size of dataset is " + str(self.data_handler.numimages() * self.data_handler.num_crops))
+        #pdb.set_trace()
         for epoch in range(epochs):
-            #self.data_handler.reset()
-            #self.data_handler.generateData(500)
-            data_gen = gen_data.gen_data_batch(self.data_handler )
+            self.data_handler.reset()
+            self.data_handler.generateData(500)
+            #data_gen = gen_data.gen_data_batch(self.data_handler )
             for i in range(total_batch):
-                '''
+                data_runout_flag, one_batch_image , one_batch_label = self.data_handler.fetch(batch_size)
+                '''
                 if data_runout_flag == False:
                     if self.data_handler.remimages() > 0:
                         self.data_handler.generateData(500)
@@ -170,15 +174,15 @@ def train(self, batch_size, epochs):
                         self.data_handler.generateData(500)
                         data_runout_flag, one_batch_image , one_batch_label = self.data_handler.fetch(batch_size)
                 '''
-                one_batch_image, np_poses_x, np_poses_q = next(data_gen)
-                one_batch_label = np.hstack((np_poses_x, np_poses_q))
+                #one_batch_image, np_poses_x, np_poses_q = next(data_gen)
+                #one_batch_label = np.hstack((np_poses_x, np_poses_q))
                 feeds ={self.image_inputs: one_batch_image, self.label_inputs: one_batch_label }
                 summary, loss, gradients = self.sess.run([self.merged_summary, self.loss, self.compute_gradients ], feed_dict=feeds)
                 self.sess.run([self.train_op], feed_dict=feeds )
                 print("[Epoch "+str(epoch)+" trainer] Train one batch of size "+str(batch_size)+", loss is "+str(loss))
                 total_loss += loss
                 self.train_writer.add_summary(summary, epoch * total_batch + i)
-
+            avg_loss = total_loss / total_batch
             self.saver.save(self.sess, "./model_epoch_"+str(epoch)+".ckpt")
             if epoch > 0:
                 delete_network_backups("./model_epoch_"+str(epoch-1)+".ckpt" )
@@ -189,12 +193,12 @@ def train(self, batch_size, epochs):

 if __name__ == "__main__":
     argv = sys.argv
-    tf.logging.set_verbosity(tf.logging.ERROR)
-    if len(sys.argv) < 4:
-        argv = ['', '', '', '']
+    if len(sys.argv) < 5:
+        argv = ['' for _ in range(5)]
         argv[1] = './vgg.data'
         argv[2] = './ShopFacade/'
-        argv[3] = bool(int(False))
+        argv[3] = 100
+        argv[4] = bool(int(False))
     os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-    train_thread = trainer(argv[1], argv[2], bool(int(argv[3])))
+    train_thread = trainer(argv[1], argv[2], int(argv[3]), bool(int(argv[4])))
    train_thread.train(32, 600)
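
For context, a minimal usage sketch of the updated entry point: the third positional argument is now the loss-weighting factor beta forwarded to build_loss, and the fourth is the resume flag. The weight and dataset paths below are only the script's own defaults, not requirements.

# Hypothetical invocation matching the new argv layout
# (weights path, dataset path, beta, resume flag):
#
#     python train.py ./vgg.data ./ShopFacade/ 100 0
#
# or, equivalently, from Python:
from train import trainer

train_thread = trainer('./vgg.data', './ShopFacade/', beta=100, resume_training=False)
train_thread.train(32, 600)  # batch size 32, 600 epochs, as in __main__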