
findphone

Find a cellphone in an image and determine the normalized coordinates of its center within the image.

Background

This is a classic localization problem. Given a set of images without bounding-box annotations and a set of labels giving the normalized center point of the phone in each image, the task was to localize the phone in each image and report its center point in normalized coordinates. I used transfer learning with a frozen inference graph based on the Faster R-CNN ResNet-101 model (faster_rcnn_resnet101_coco_2018_01_28) from the TensorFlow detection model zoo. The zoo lists this model at a COCO mAP of 32 and a speed of 106 ms, i.e., relatively high precision at a relatively slow speed.
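The README does not spell out the coordinate convention, so the sketch below assumes the usual one: the origin at the top-left corner, x measured along the image width and y along the height, each divided by the corresponding image dimension so that both fall in [0, 1]. The function names and the 490 x 326 image size are hypothetical, for illustration only.

```python
def to_normalized(x_px, y_px, width, height):
    """Map a pixel position (origin at top-left) to normalized (x, y) in [0, 1]."""
    if width <= 0 or height <= 0:
        raise ValueError("image width and height must be positive")
    return x_px / width, y_px / height


def to_pixels(x_norm, y_norm, width, height):
    """Inverse of to_normalized: map normalized (x, y) back to pixel units."""
    return x_norm * width, y_norm * height


# Hypothetical 490 x 326 image with the phone centered at pixel (245, 163).
print(to_normalized(245, 163, 490, 326))  # (0.5, 0.5)
```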

Overall Task

Step 1: Train the Model on Phone Images ("train_phone_finder")

This was accomplished using the TensorFlow Object Detection API and the utility scripts supplied with it, in the following steps:

  1. Installing the TensorFlow Object Detection API.
  2. Installing the labelImg annotation tool.
  3. Tagging the phone in each training image with labelImg, which saves each bounding box as a Pascal VOC XML file.
  4. Creating TFRecord files from the Pascal VOC XML files with a modified version of create_pet_tf_record.py from the Object Detection API (located under the cloned API directory at ./tensorflow/models/research/object_detection/dataset_tools). I initially did not realize that create_pet_tf_record.py derives the class name of each object from the image filename, so I hardcoded the class of interest ("cellphone") instead.
  5. Installing the Faster R-CNN ResNet-101 model from the model zoo.
  6. Modifying pipeline.config to shrink the input queues in the train_config section (batch_queue_capacity and prefetch_queue_capacity both set to 4). My very old Dell Dimension E520 has little memory, and training stalled without these changes.
  7. Training on the sample images for more than 8 hours, until the total loss fell below ~0.05, ending with a checkpoint numbered around 254. The trained checkpoint was then exported as a frozen inference graph that can be used to detect the phone, i.e., to classify it and produce its bounding box.
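Steps 3 and 4 above can be sketched without TensorFlow: read one Pascal VOC XML annotation written by labelImg, convert its pixel bounding box to normalized coordinates, and attach a hardcoded class rather than one derived from the filename. This is an illustrative sketch, not the modified create_pet_tf_record.py; the CLASS_ID value of 1 is an assumption that must match your label map, and the sample annotation is hypothetical. In the stock script, values like these end up in tf.train.Example features such as image/object/bbox/xmin and image/object/class/text, which this sketch does not write.

```python
import xml.etree.ElementTree as ET

# Hardcoded class, instead of deriving it from the image filename.
CLASS_NAME = "cellphone"
CLASS_ID = 1  # assumption: must match the id in your label map


def parse_voc(xml_text):
    """Return (filename, boxes) for one Pascal VOC annotation.

    Each box has normalized xmin/ymin/xmax/ymax plus the hardcoded class;
    the <name> tag written by labelImg is deliberately ignored.
    """
    root = ET.fromstring(xml_text)
    width = float(root.find("size/width").text)
    height = float(root.find("size/height").text)
    boxes = []
    for obj in root.findall("object"):
        bb = obj.find("bndbox")
        boxes.append({
            "xmin": float(bb.find("xmin").text) / width,
            "ymin": float(bb.find("ymin").text) / height,
            "xmax": float(bb.find("xmax").text) / width,
            "ymax": float(bb.find("ymax").text) / height,
            "class_text": CLASS_NAME,
            "class_id": CLASS_ID,
        })
    return root.find("filename").text, boxes


# Hypothetical annotation for a 490 x 326 image.
SAMPLE = """<annotation>
  <filename>0.jpg</filename>
  <size><width>490</width><height>326</height><depth>3</depth></size>
  <object>
    <name>phone</name>
    <bndbox><xmin>49</xmin><ymin>163</ymin><xmax>98</xmax><ymax>326</ymax></bndbox>
  </object>
</annotation>"""

print(parse_voc(SAMPLE))
```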

Step 2: Run Images through Model ("find_phone")

This was accomplished by modifying the tutorial notebook supplied with the TensorFlow Object Detection API, in the following steps:

  1. Installing Jupyter Notebook.
  2. Modifying the stock notebook script. The tutorial notebook is located under the cloned API directory at ./tensorflow/models/research/object_detection/object_detection_tutorial.ipynb.
  3. Removing the code that downloads the model TAR files.
  4. Parameterizing the main code block to read the image file passed in as a command-line argument, and adding a function to find the center of the bounding box.
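The bounding-box center function from step 4 might look like the following minimal sketch. It assumes the output format of the tutorial notebook's inference, where detection_boxes holds boxes as normalized [ymin, xmin, ymax, xmax] and detection_scores holds the matching confidences. The function name phone_center and the 0.5 score threshold are hypothetical choices, not taken from find_phone.py.

```python
def phone_center(boxes, scores, min_score=0.5):
    """Return the normalized (x, y) center of the highest-scoring box, or None.

    boxes:  sequence of [ymin, xmin, ymax, xmax] in normalized coordinates
            (the layout of the API's detection_boxes output).
    scores: confidences aligned with boxes (detection_scores).
    Works with plain lists or NumPy arrays.
    """
    if len(boxes) == 0:
        return None
    best = max(range(len(scores)), key=lambda i: scores[i])
    if scores[best] < min_score:
        return None  # nothing confident enough to call a phone
    ymin, xmin, ymax, xmax = boxes[best]
    return ((xmin + xmax) / 2.0, (ymin + ymax) / 2.0)


print(phone_center([[0.0, 0.5, 0.5, 1.0], [0.1, 0.1, 0.2, 0.2]], [0.95, 0.40]))
# (0.75, 0.25)
```

Taking only the top-scoring box reflects the task's assumption of one phone per image; a lower threshold trades false negatives for false positives.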

Downloading the Inference Graph

The ResNet inference graph is a large file, so it is stored on Google Drive as a .tar.gz archive that is available from me on request. It must be placed in the same directory as the two scripts described below (train_phone_finder.py and find_phone.py). The archive does not need to be extracted (decompressed) before it is copied into the directory from which the scripts will be run.

Two Executable Scripts Submitted

Two scripts were developed to accomplish the task:

  1. train_phone_finder.py: Extracts phone_graph.pb from the .tar.gz archive, and nothing more. The model has already been trained on the sample training image set, so no further training is needed. The task description allows find_phone.py to "use data in the local folder previously generated by train_phone_finder.py", and that is exactly what it does. The script ignores command-line arguments.
  2. find_phone.py: Runs an image through the inference graph to detect the phone and outputs the normalized coordinates of the phone's center. The script is invoked as specified in the task description.
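A minimal sketch of what train_phone_finder.py is described as doing, using only the standard library. The archive name phone_graph.tar.gz is hypothetical (the README does not give it), and the sketch copies out only phone_graph.pb, wherever it sits inside the archive, rather than extracting everything.

```python
import os
import shutil
import tarfile

ARCHIVE_NAME = "phone_graph.tar.gz"  # hypothetical name; use the real archive
GRAPH_NAME = "phone_graph.pb"


def extract_graph(directory="."):
    """Copy phone_graph.pb out of the archive into `directory`; return its path."""
    archive_path = os.path.join(directory, ARCHIVE_NAME)
    out_path = os.path.join(directory, GRAPH_NAME)
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            # Match on the base name so a folder prefix inside the archive is fine.
            if member.isfile() and os.path.basename(member.name) == GRAPH_NAME:
                with tar.extractfile(member) as src, open(out_path, "wb") as dst:
                    shutil.copyfileobj(src, dst)
                return out_path
    raise FileNotFoundError(f"{GRAPH_NAME} not found in {archive_path}")
```

Calling extract_graph() from the scripts' directory would leave phone_graph.pb next to find_phone.py. Copying the single member's bytes, instead of calling extractall, also avoids writing any unexpected paths from the archive to disk.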

Results

The ResNet model detected phones well but ran very slowly on my old desktop computer (~44 seconds per image). Accuracy on the training set was above 93%, so I expect, though have not verified, similar performance on the test image set. Faster models such as SSD MobileNet should be investigated to reduce inference time. In addition, I did no preprocessing of the image files before detection, and even basic normalization could improve results. Additional training images would also improve accuracy further.
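The README does not say how the 93% figure was computed. One common way to score this kind of task is to count a prediction as correct when the predicted center lies within a fixed radius of the labeled center. The sketch below does that with a radius of 0.05 (in normalized units), chosen purely as a hypothetical example; it is not necessarily the criterion behind the reported number, and the sample filenames and coordinates are made up.

```python
import math


def localization_accuracy(predictions, labels, radius=0.05):
    """Fraction of labeled images whose predicted center is within `radius`.

    predictions, labels: dicts mapping image filename -> normalized (x, y).
    A missing or None prediction counts as a miss. radius=0.05 is a
    hypothetical default, not a value taken from the README.
    """
    if not labels:
        raise ValueError("labels must not be empty")
    hits = 0
    for name, (lx, ly) in labels.items():
        pred = predictions.get(name)
        if pred is not None and math.hypot(pred[0] - lx, pred[1] - ly) <= radius:
            hits += 1
    return hits / len(labels)


labels = {"a.jpg": (0.5, 0.5), "b.jpg": (0.2, 0.8)}
predictions = {"a.jpg": (0.52, 0.49), "b.jpg": (0.4, 0.8)}
print(localization_accuracy(predictions, labels))  # 0.5
```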

References

TensorFlow object detection API installation
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
TensorFlow detection model zoo
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
TensorFlow object detection API
https://github.com/tensorflow/models/tree/master/research/object_detection
labelImg software
https://github.com/saicoco/object_labelImg
Is Google Tensorflow Object Detection API the easiest way to implement image recognition?
https://towardsdatascience.com/is-google-tensorflow-object-detection-api-the-easiest-way-to-implement-image-recognition-a8bd1f500ea0
Building a Toy Detector with Tensorflow Object Detection API
https://towardsdatascience.com/building-a-toy-detector-with-tensorflow-object-detection-api-63c0fdf2ac95
Step by Step TensorFlow Object Detection API Tutorial — Part 1: Selecting a Model
https://medium.com/@WuStangDan/step-by-step-tensorflow-object-detection-api-tutorial-part-1-selecting-a-model-a02b6aabe39e
Build a Taylor Swift detector with the TensorFlow Object Detection API, ML Engine, and Swift
https://towardsdatascience.com/build-a-taylor-swift-detector-with-the-tensorflow-object-detection-api-ml-engine-and-swift-82707f5b4a56