Skip to content

Commit

Permalink
ready to merge
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroAska committed Apr 7, 2018
2 parents 16db23b + a821bac commit 4c232da
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 23 deletions.
57 changes: 48 additions & 9 deletions data_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import cv2
import sys, os
import cv2, sys, os, shutil, csv
from random import randint, sample
import numpy as np
import pdb
Expand Down Expand Up @@ -53,14 +52,24 @@ def __init__(self, location, file = None, flag = False, height = 224, width = 22
self.getGround()

def OSWalk(self):
images = [os.path.join(root, name) for root, dirs, files in os.walk(self.location)
images = [os.path.join(root, name) for root, dirs, files in os.walk(self.location + 'Train/')
for name in files if name.endswith((".png", ".jpg", ".jpeg", ".gif", ".tiff"))]
self.imageLocs = images


# for i in range(len(images)):
# a,b = images[i].split('\\')
# self.imageLocs[i] = a + '/' + b

if self.flag==False:
for i in range(len(images)):
a,b = images[i].split('/')
self.imageLocs[i] = a + '/' + b
else:
for i in range(len(images)):
self.imageLocs[i] = images[i]


# array to store which image has been used
self.numImages = len(self.imageLocs)
self.remImages = np.ones((self.numImages), dtype = bool)
Expand Down Expand Up @@ -110,16 +119,18 @@ def process(self, num, indices):
print('Generating Samples .... Note: no img normalization ')

self.sampleImages = np.zeros((num*self.num_crops , self.height, self.width, self.depth), dtype=np.uint8)

idx = 0
imgs = np.zeros((num, 256, 455, 3))
names = []
names = ['' for i in range(num)]
for i in range(num):
# read image
img = cv2.resize(cv2.imread(indices[i], 1).astype(float), (455,256), interpolation = cv2.INTER_CUBIC)
imgs[i] = img

name = self.getName(indices[i])
names.append(name)
if (name in self.img2labels ):
names[i] = name
imgs[i,:,:,:] = img

means = np.mean(imgs, axis=0)
imgs = imgs - means

Expand All @@ -146,10 +157,15 @@ def process(self, num, indices):
y = randint(0,230)

self.sampleImages[j, :, :, :] = img[x:x + self.height, y:y + self.width, :].copy()
self.idx2img[j] = names[i]
try:
self.idx2img[j] = names[i]
except IndexError:
pdb.set_trace()
idx += self.num_crops


for i in range(self.numImages):
if self.remImages[i] == False:
self.store(self.imageLocs[i])

def fetch(self, num):
samples = np.zeros((num, self.height, self.width, self.depth), dtype=np.float)
Expand Down Expand Up @@ -233,6 +249,29 @@ def getName(self, loc):
return int(name)


def store(self, image):
name = self.getName(image)
location = self.location + 'usedImages/'
# copy image to folder
if os.path.exists(location):
shutil.rmtree(location)
if not os.path.exists(location):
os.makedirs(location)
shutil.copy(image, location + str(name) + '.tiff')

# write image with labels to folder
file = location + 'trainingSet.csv'
csv = open(file, "a")
x = str(self.img2labels[name][0])
y = str(self.img2labels[name][1])
z = str(self.img2labels[name][2])
r = str(self.img2labels[name][3])
p = str(self.img2labels[name][4])
h = str(self.img2labels[name][5])

row = str(name) + ',' + x + ',' + y + ',' + z + ',' + r + ',' + p + ',' + h + '\n'
csv.write(row)

def reset(self):
self.remImages = np.ones((self.numImages), dtype = bool)

Expand Down
13 changes: 4 additions & 9 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,25 @@
# location = "./2013-01-10/"
# file = 'groundtruth_2013-01-10.csv'

location = "/media/eecs568/My Passport/NCLT/2012-03-17/2012-03-17/"
location = "/home/eecs568/eecs568/Mobile-Robotics/nclt/"
file = 'groundtruth_2012-03-17.csv'

dh = DH.Process(location, file, True)

# pick number of images to pick samples from
numImages = 100
numImages = 1000
dh.generateData(numImages)
dh.remimages()

dh.numsamples()
flag, images, labels = dh.fetch(2)
print(labels)
flag, images, labels = dh.fetch(6)
print(images.shape, len(labels), flag)
dh.remsamples()


# Trick to pick samples from selected images
numSamples = 60
numSamples = 10
flag = True
while flag==True:
flag, images, labels = dh.fetch(numSamples)
cv2.imshow('image', images[0,:,:,:])
cv2.waitKey(0)
pdb.set_trace()
dh.remsamples()

11 changes: 6 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys, os
sys.path.insert(0, '/home/eecs568/miniconda3/envs/tensorflow/lib/python3.5/site-packages')
#sys.path.insert(0, '/home/eecs568/miniconda3/envs/tensorflow/lib/python3.5/site-packages')
import data_handler
from datetime import datetime
import numpy as np
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(self,path_to_weight, path_to_data, beta, use_quaternion=True, resum
print("Model initialized")

def init_data_handler(self,path_to_data):
self.data_handler = data_handler.Process(path_to_data)
self.data_handler = data_handler.Process(path_to_data, 'dataset_train.txt', False)
#self.data_handler = gen_data.get_data()

def load_weight(self,path_to_weight):
Expand Down Expand Up @@ -154,8 +154,9 @@ def test(self, img, num_random_crops=20):
def train(self, batch_size, epochs):

total_loss = 0

total_batch = int(self.data_handler.numimages() * self.data_handler.num_crops / batch_size) #100
total_batch = int(self.data_handler.numimages() * self.data_handler.num_crops * 1.0 / batch_size) #100
if total_batch==0:
pdb.set_trace()
#print("[trainer] Start Training, size of dataset is " +str(self.data_handler.numimages() * self.data_handler.num_crops ))
#pdb.set_trace()
for epoch in range(epochs):
Expand Down Expand Up @@ -200,5 +201,5 @@ def train(self, batch_size, epochs):
argv[3] = 100
argv[4] = bool(int(False))
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
train_thread = trainer(argv[1], argv[2], int(argv[3]), bool(int(argv[4])))
train_thread = trainer(argv[1], argv[2], int(argv[3]), use_quaternion=True, resume_training=False )
train_thread.train(32, 600)

0 comments on commit 4c232da

Please sign in to comment.