Skip to content

Commit

Permalink
Merge pull request #33 from tryolabs/metrics_fix
Browse files Browse the repository at this point in the history
Small refactor to metrics
  • Loading branch information
joaqo authored Dec 22, 2020
2 parents dd7758a + 2c35553 commit 04f80dd
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 34 deletions.
4 changes: 3 additions & 1 deletion demos/yolov4/yolov4demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
class YOLO:
def __init__(self, weightfile, use_cuda=True):
if use_cuda and not torch.cuda.is_available():
raise Exception("Selected use_cuda=True, but cuda is not available to Pytorch")
raise Exception(
"Selected use_cuda=True, but cuda is not available to Pytorch"
)
self.use_cuda = use_cuda
self.model = Yolov4(yolov4conv137weight=None, n_classes=80, inference=True)
pretrained_dict = torch.load(
Expand Down
13 changes: 9 additions & 4 deletions norfair/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def draw_points(
detections: Sequence["Detection"],
radius: Optional[int] = None,
thickness: Optional[int] = None,
color: Optional[Tuple[int, int, int]] = None
color: Optional[Tuple[int, int, int]] = None,
):
if detections is None:
return
Expand Down Expand Up @@ -114,7 +114,10 @@ def draw_debug_metrics(
radius = int(frame_scale * 0.5)

for obj in objects:
if not (obj.last_detection.scores is None) and not (obj.last_detection.scores > draw_score_threshold).any():
if (
not (obj.last_detection.scores is None)
and not (obj.last_detection.scores > draw_score_threshold).any()
):
continue
if only_ids is not None:
if obj.id not in only_ids:
Expand All @@ -127,7 +130,9 @@ def draw_debug_metrics(
else:
text_color = color
draw_position = centroid(
obj.estimate[obj.last_detection.scores > draw_score_threshold] if obj.last_detection.scores is not None else obj.estimate
obj.estimate[obj.last_detection.scores > draw_score_threshold]
if obj.last_detection.scores is not None
else obj.estimate
)

for point in obj.estimate:
Expand Down Expand Up @@ -284,6 +289,6 @@ def random(obj_id: int) -> Tuple[int, int, int]:
c
for c in Color.__dict__.keys()
if c[:2] != "__"
and c not in ("random", "red", "white", "grey", "black", "silver")
and c not in ("random", "red", "white", "grey", "black", "silver")
]
return getattr(Color, color_list[obj_id % len(color_list)])
27 changes: 14 additions & 13 deletions norfair/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,22 @@

class InformationFile:
def __init__(self, file_path):
self.path = file_path
with open(file_path, "r") as myfile:
self.file = myfile.read()
file = myfile.read()
self.lines = file.splitlines()

def search(self, variable_name):
index_position_on_this_document = self.file.find(variable_name)
index_position_on_this_document = index_position_on_this_document + len(
variable_name
)
while not self.file[index_position_on_this_document].isdigit():
index_position_on_this_document += 1
value_string = ""
while self.file[index_position_on_this_document].isdigit():
value_string += self.file[index_position_on_this_document]
index_position_on_this_document += 1
return int(value_string)
for line in self.lines:
if line[: len(variable_name)] == variable_name:
result = line[len(variable_name) + 1 :]
break
else:
raise ValueError(f"Couldn't find '{variable_name}' in {self.path}")
if result.isdigit():
return int(result)
else:
return result


class PredictionsTextFile:
Expand Down Expand Up @@ -320,7 +321,7 @@ def eval_motChallenge(matrixes_predictions, paths, metrics=None, generate_overal

accs, names = compare_dataframes(gt, ts)

if metrics == None:
if metrics is None:
metrics = list(mm.metrics.motchallenge_metrics)
mm.lap.default_solver = "scipy"
print("Computing metrics...")
Expand Down
33 changes: 22 additions & 11 deletions norfair/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,13 @@ def __init__(
self.hit_inertia_max = hit_inertia_max

if initialization_delay is None:
self.initialization_delay = int((self.hit_inertia_max - self.hit_inertia_min) / 2)
elif initialization_delay < 0 or initialization_delay > self.hit_inertia_max - self.hit_inertia_min:
self.initialization_delay = int(
(self.hit_inertia_max - self.hit_inertia_min) / 2
)
elif (
initialization_delay < 0
or initialization_delay > self.hit_inertia_max - self.hit_inertia_min
):
raise ValueError(
f"Argument 'initialization_delay' for 'Tracker' class should be an int between 0 and (hit_inertia_max - hit_inertia_min = {hit_inertia_max - hit_inertia_min}). The selected value is {initialization_delay}.\n"
)
Expand All @@ -38,8 +43,7 @@ def __init__(
self.point_transience = point_transience
TrackedObject.count = 0

def update(self, detections: Optional[List["Detection"]] = None,
period: int = 1):
def update(self, detections: Optional[List["Detection"]] = None, period: int = 1):
self.period = period

# Remove stale trackers and make candidate object real if it has hit inertia
Expand Down Expand Up @@ -75,8 +79,11 @@ def update(self, detections: Optional[List["Detection"]] = None,

return [p for p in self.tracked_objects if not p.is_initializing]

def update_objects_in_place(self, objects: Sequence["TrackedObject"],
detections: Optional[List["Detection"]]):
def update_objects_in_place(
self,
objects: Sequence["TrackedObject"],
detections: Optional[List["Detection"]],
):
if detections is not None and len(detections) > 0:
distance_matrix = np.ones((len(detections), len(objects)), dtype=np.float32)
distance_matrix *= self.distance_threshold + 1
Expand Down Expand Up @@ -208,14 +215,18 @@ def __init__(
self.detection_threshold: float = detection_threshold
self.initial_period: int = period
self.hit_counter: int = hit_inertia_min + period
self.point_hit_counter: np.ndarray = np.ones(self.num_points) * self.point_hit_inertia_min
self.point_hit_counter: np.ndarray = (
np.ones(self.num_points) * self.point_hit_inertia_min
)
self.last_distance: Optional[float] = None
self.current_min_distance: Optional[float] = None
self.last_detection: "Detection" = initial_detection
self.age: int = 0
self.is_initializing_flag: bool = True
self.id: Optional[int] = None
self.initializing_id: int = TrackedObject.initializing_count # Just for debugging
self.initializing_id: int = (
TrackedObject.initializing_count
) # Just for debugging
TrackedObject.initializing_count += 1
self.setup_filter(initial_detection.points)
self.detected_at_least_once_points = np.array([False] * self.num_points)
Expand Down Expand Up @@ -281,7 +292,7 @@ def has_inertia(self):
@property
def estimate(self):
positions = self.filter.x.T.flatten()[: self.dim_z].reshape(-1, 2)
velocities = self.filter.x.T.flatten()[self.dim_z:].reshape(-1, 2)
velocities = self.filter.x.T.flatten()[self.dim_z :].reshape(-1, 2)
return positions

@property
Expand Down Expand Up @@ -315,7 +326,7 @@ def hit(self, detection: "Detection", period: int = 1):
self.point_hit_counter += 2 * period
self.point_hit_counter[
self.point_hit_counter >= self.point_hit_inertia_max
] = self.point_hit_inertia_max
] = self.point_hit_inertia_max
self.point_hit_counter[self.point_hit_counter < 0] = 0
H_vel = np.zeros(H_pos.shape) # But we don't directly measure velocity
H = np.hstack([H_pos, H_vel])
Expand All @@ -330,7 +341,7 @@ def hit(self, detection: "Detection", period: int = 1):
detected_at_least_once_mask = np.array(
[[m, m] for m in self.detected_at_least_once_points]
).flatten()
self.filter.x[self.dim_z:][np.logical_not(detected_at_least_once_mask)] = 0
self.filter.x[self.dim_z :][np.logical_not(detected_at_least_once_mask)] = 0
self.detected_at_least_once_points = np.logical_or(
self.detected_at_least_once_points, points_over_threshold_mask
)
Expand Down
2 changes: 2 additions & 0 deletions norfair/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def validate_points(points: np.array) -> np.array:
print_detection_error_message_and_exit(points)
return points


def print_detection_error_message_and_exit(points):
print("\n[red]INPUT ERROR:[/red]")
print(
Expand All @@ -27,6 +28,7 @@ def print_detection_error_message_and_exit(points):
print("https://github.com/tryolabs/norfair/tree/master/docs#detection\n")
exit()


def print_objects_as_table(tracked_objects: Sequence):
"""Used for helping in debugging"""
print()
Expand Down
18 changes: 13 additions & 5 deletions norfair/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ def __init__(
self.input_path = os.path.expanduser(self.input_path)
if not os.path.isfile(self.input_path):
self._fail(
f"[bold red]Error:[/bold red] File '{self.input_path}' does not exist.")
f"[bold red]Error:[/bold red] File '{self.input_path}' does not exist."
)
self.video_capture = cv2.VideoCapture(self.input_path)
total_frames = int(self.video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
if total_frames == 0:
self._fail(
f"[bold red]Error:[/bold red] '{self.input_path}' does not seem to be a video file supported by OpenCV. If the video file is not the problem, please check that your OpenCV installation is working correctly.")
f"[bold red]Error:[/bold red] '{self.input_path}' does not seem to be a video file supported by OpenCV. If the video file is not the problem, please check that your OpenCV installation is working correctly."
)
description = os.path.basename(self.input_path)
else:
self.video_capture = cv2.VideoCapture(self.camera)
Expand Down Expand Up @@ -185,7 +187,9 @@ def get_codec_fourcc(self, filename: str) -> Optional[str]:
f"[yellow]{filename}[/yellow]\n"
f"Please use '.mp4', '.avi', or provide a custom OpenCV fourcc codec name."
)
return None # Had to add this return to make mypya happy. I don't like this.
return (
None # Had to add this return to make mypya happy. I don't like this.
)

def abbreviate_description(self, description: str) -> str:
"""Conditionally abbreviate description so that progress bar fits in small terminals"""
Expand All @@ -198,7 +202,7 @@ def abbreviate_description(self, description: str) -> str:
else:
return "{} ... {}".format(
description[: space_for_description // 2 - 3],
description[-space_for_description // 2 + 3:],
description[-space_for_description // 2 + 3 :],
)


Expand Down Expand Up @@ -234,6 +238,8 @@ def __init__(self, input_path, save_path=".", information_file=None):
self.input_path = input_path
self.frame_number = 1
self.video = cv2.VideoWriter(video_path, fourcc, fps, image_size) # Video file
self.image_extension = information_file.search("imExt")
self.image_directory = information_file.search("imDir")

def __iter__(self):
self.frame_number = 1
Expand All @@ -242,7 +248,9 @@ def __iter__(self):
def __next__(self):
if self.frame_number <= self.length:
frame_path = os.path.join(
self.input_path, "img1", str(self.frame_number).zfill(6) + ".jpg"
self.input_path,
self.image_directory,
str(self.frame_number).zfill(6) + self.image_extension,
)
self.frame_number += 1

Expand Down

0 comments on commit 04f80dd

Please sign in to comment.