Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] separate image message & OpenCV image classification, use dictionary for classes #12

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/python_package.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,16 @@ An extension of `SceneDetectionActionServer` which uses `SingleImageDetectionHan
from an image extracted from a `sensor_msgs/PointCloud2` message, while also fitting planes in the clouds.

## [`utils.py`](../ros/src/mas_perception_libs/utils.py)
* `get_classes_in_data_dir`: Returns a list of strings as class names for a directory. This directory structure
* `get_classes_in_data_dir`: Returns a dictionary mapping from indices to classes as names of top level directories.
This directory structure
```
data
├── class_1
└── class_2
```
should returns
minhnh marked this conversation as resolved.
Show resolved Hide resolved
```
['class_1', 'class_2']
{0: 'class_1', 1: 'class_2'}
```
when called on `data`.
* `process_image_message`: Converts `sensor_msgs/Image` to CV image, then resizes and/or runs a preprocessing function
Expand Down
54 changes: 29 additions & 25 deletions ros/src/mas_perception_libs/image_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import yaml
from abc import ABCMeta, abstractmethod
import numpy as np
from cv_bridge import CvBridge
Expand All @@ -11,7 +12,7 @@ class ImageClassifier(object):
"""
__metaclass__ = ABCMeta

_classes = None # type: list
_classes = None # type: dict

def __init__(self, **kwargs):
# read information on classes, either directly, via a file, or from a data directory
Expand All @@ -20,7 +21,11 @@ def __init__(self, **kwargs):
if self._classes is None:
class_file = kwargs.get('class_file', None)
if class_file is not None and os.path.exists(class_file):
self._classes = ImageClassifier.read_classes_from_file(class_file)
with open(class_file) as infile:
if yaml.__version__ < '5.1':
self._classes = yaml.load(infile)
else:
self._classes = yaml.load(infile, Loader=yaml.FullLoader)

if self._classes is None:
data_dir = kwargs.get('data_dir', None)
Expand All @@ -34,7 +39,7 @@ def __init__(self, **kwargs):
@property
def classes(self):
"""
list of strings containing class names TODO(minhnh): make this dictionary from predicted class to class name
dictionary mapping from predicted numeric class value to class name
"""
return self._classes

Expand All @@ -52,17 +57,6 @@ def classify(self, image_messages):
"""
pass

@staticmethod
def write_classes_to_file(classes, outfile_path):
with open(outfile_path, 'w') as outfile:
outfile.write('\n'.join(classes))

@staticmethod
def read_classes_from_file(infile):
with open(infile) as f:
content = f.readlines()
return [x.strip() for x in content]


class ImageClassifierTest(ImageClassifier):
"""
Expand Down Expand Up @@ -113,26 +107,36 @@ def __init__(self, **kwargs):
# CvBridge for ROS image conversion
self._cv_bridge = CvBridge()

def classify(self, image_messages):
if len(image_messages) == 0:
return [], [], []

np_images = [process_image_message(msg, self._cv_bridge, self._target_size, self._img_preprocess_func)
for msg in image_messages]

image_array = []
def classify_np_images(self, np_images):
"""
Classify NumPy images
"""
image_tensor = []
indices = []
for i in range(len(np_images)):
if np_images[i] is None:
# skip broken images
continue

image_array.append(np_images[i])
image_tensor.append(np_images[i])
indices.append(i)

image_array = np.array(image_array)
preds = self._model.predict(image_array)
image_tensor = np.array(image_tensor)
preds = self._model.predict(image_tensor)
class_indices = np.argmax(preds, axis=1)
confidences = np.max(preds, axis=1)
predicted_classes = [self._classes[i] for i in class_indices]

return indices, predicted_classes, confidences

def classify(self, image_messages):
"""
Classify ROS `sensor_msgs/Image` messages
"""
if len(image_messages) == 0:
return [], [], []

np_images = [process_image_message(msg, self._cv_bridge, self._target_size, self._img_preprocess_func)
for msg in image_messages]

return self.classify_np_images(np_images)
2 changes: 1 addition & 1 deletion ros/src/mas_perception_libs/image_recognition_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def handle_recognize_image(self, req):
rospy.loginfo('number of images to recognize: ' + str(len(req.images)))
if req.model_name not in self._classifiers:
model_path = os.path.join(self._model_dir, req.model_name + '.h5')
class_file = os.path.join(self._model_dir, req.model_name + '.txt')
class_file = os.path.join(self._model_dir, req.model_name + '.yml')
rospy.loginfo('recognition model path: ' + model_path)
rospy.loginfo('recognition class file path: ' + class_file)
self._classifiers[req.model_name] = self._classifier_class(model_path=model_path, class_file=class_file)
Expand Down
10 changes: 6 additions & 4 deletions ros/src/mas_perception_libs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,16 @@ def get_bag_file_msg_by_type(bag_file_path, msg_type):
def get_classes_in_data_dir(data_dir):
"""
:type data_dir: str
:return: list of classes as names of top level directories
:return: dictionary mapping from indices to classes as names of top level directories
"""
classes = []
class_dict = {}
index = 0
for subdir in sorted(os.listdir(data_dir)):
if os.path.isdir(os.path.join(data_dir, subdir)):
classes.append(subdir)
class_dict[index] = subdir
index += 1

return classes
return class_dict


def process_image_message(image_msg, cv_bridge, target_size=None, func_preprocess_img=None):
Expand Down