Skip to content

Commit

Permalink
Extract prediction to method in live_prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
Kalior committed Aug 28, 2018
1 parent e77346c commit 5f0450c
Showing 1 changed file with 35 additions and 34 deletions.
69 changes: 35 additions & 34 deletions live_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,14 @@ def main(args):
copyfile(args.video, tmp_video_file)

classifier = joblib.load(args.classifier)

detector = CaffeOpenpose(args.model_path)
tracker = Tracker(detector, out_dir=args.out_directory)

processor = PostProcessor()

classes = classifier.classes_
logging.info("Classes: {}".format(classes))
logging.info("Classes: {}".format(classifier.classes_))

track_people_start = time()
valid_predictions = []
track_people_start = time()
for tracks, img, current_frame in tracker.video_generator(args.video, args.draw_frames):
# Don't predict every frame, not enough has changed for it to be valuable.
if current_frame % 20 != 0 or len(tracks) <= 0:
Expand All @@ -43,38 +41,47 @@ def main(args):
tracks = [track for track in tracks
if track.recently_updated(current_frame)]

logging.debug("Number of tracks: {}".format(len(tracks)))
track_people_time = time() - track_people_start
logging.debug("Number of tracks: {}".format(len(tracks)))

predict_people_start = time()
# Extract the latest frames, as we don't want to copy
# too much data here, and we've already predicted for the rest
processor.tracks = [copy.deepcopy(t.copy(-50)) for t in tracks]
processor.post_process_tracks()

predictions = [predict_per_track(t, classifier) for t in processor.tracks]

valid_predictions = filter_bad_predictions(
predictions, args.confidence_threshold, classes)
save_predictions_to_track(predictions, classes, tracks, current_frame)
valid_predictions = predict(tracks, classifier, current_frame, args.confidence_threshold)

predict_people_time = time() - predict_people_start

no_stop_predictions = [predict_no_stop(track, args.confidence_threshold)
for track in tracks]

for t in [t for p, t in no_stop_predictions if p]:
valid_predictions.append(t)

write_predictions(valid_predictions, img)
save_predictions(valid_predictions, args.video, tmp_video_file, args.out_directory)

logging.info("Predict time: {:.3f}, Track time: {:.3f}".format(
predict_people_time, track_people_time))
log_predictions(predictions, no_stop_predictions)
track_people_start = time()


def predict(tracks, classifier, current_frame, confidence_threshold):
# Extract the latest frames, as we don't want to copy
# too much data here, and we've already predicted for the rest
processor = PostProcessor()
processor.tracks = [copy.deepcopy(t.copy(-50)) for t in tracks]
processor.post_process_tracks()

predictions = [predict_per_track(t, classifier) for t in processor.tracks]

valid_predictions = filter_bad_predictions(
predictions, confidence_threshold, classifier.classes_)
save_predictions_to_track(predictions, classifier.classes_, tracks, current_frame)

no_stop_predictions = [predict_no_stop(track, confidence_threshold)
for track in tracks]

for t in [t for p, t in no_stop_predictions if p]:
valid_predictions.append(t)

log_predictions(predictions, no_stop_predictions, classifier.classes_)

return valid_predictions


def predict_per_track(track, classifier):
all_chunks = []
all_frames = []
Expand Down Expand Up @@ -138,10 +145,12 @@ def write_chunk_to_file(video_name, video, frames, chunk, label, out_dir, i):

def predict_no_stop(track, confidence_threshold, stop_threshold=10):
if len(track) < 50:
return 0, ()
return False, ()

classifier_prediction = classifier_predict_no_stop(track, confidence_threshold)

# Only check last 200 frames as person could have been doing something else
# before that. Makes the prediction a bit fragile.
track = track.copy(-200)
chunks, chunk_frames = track.divide_into_chunks(len(track) - 1, 0)

Expand Down Expand Up @@ -176,9 +185,6 @@ def classifier_predict_no_stop(track, confidence_threshold):


def speed_no_stop_prediction(track, chunks, stop_threshold):
# Only check last 200 frames as person could have been doing something else
# before that. Makes the prediction a bit fragile.

keypoint_speed = transforms.Speed().fit_transform(chunks)[0]
frame_speed = np.mean(keypoint_speed[:, :, :2], axis=1)
frame_speed = np.linalg.norm(frame_speed, axis=1)
Expand All @@ -191,20 +197,15 @@ def speed_no_stop_prediction(track, chunks, stop_threshold):

n_movement_frames = np.count_nonzero(frame_speed[first_movement_index:] > stop_threshold)

# Calculate how many of the last moving frames have to have had movement
# for us to predict the person did not stop to do anything.
# The 4 here is arbitrary and might make the prediction a bit fragile
movement_length = len(track) - first_movement_index
n_movement_frames_for_no_stop = movement_length - movement_length / 4

# Make sure the speed is measured over at least 100 frames.
if len(track) - first_movement_index < 100:
return 0

return n_movement_frames / n_movement_frames_for_no_stop
movement_length = len(track) - first_movement_index
return n_movement_frames / movement_length


def log_predictions(predictions, no_stop_predictions):
def log_predictions(predictions, no_stop_predictions, classes):
prints = []
for _, _, prediction in predictions:
prints.append(get_best_pred(prediction, classes))
Expand Down

0 comments on commit 5f0450c

Please sign in to comment.