Skip to content

Commit

Permalink
Improve documentation of live_prediction script
Browse files Browse the repository at this point in the history
  • Loading branch information
Kalior committed Aug 28, 2018
1 parent ada3f7e commit cbe4fd0
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions live_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,12 @@ def main(args):
logging.debug("Number of tracks: {}".format(len(tracks)))
track_people_time = time() - track_people_start

predict_people_start = time()
# Extract the latest frames, as we don't want to copy
# too much data here, and we've already predicted for the rest
processor.tracks = [copy.deepcopy(t.copy(-50)) for t in tracks]
processor.post_process_tracks()

predict_people_start = time()

predictions = [predict_per_track(t, classifier) for t in processor.tracks]

valid_predictions = filter_bad_predictions(
Expand All @@ -72,6 +71,7 @@ def main(args):

logging.info("Predict time: {:.3f}, Track time: {:.3f}".format(
predict_people_time, track_people_time))
log_predictions(predictions, no_stop_predictions)
track_people_start = time()


Expand Down Expand Up @@ -145,7 +145,7 @@ def predict_no_stop(track, confidence_threshold, stop_threshold=10):
track = track.copy(-200)
chunks, chunk_frames = track.divide_into_chunks(len(track) - 1, 0)

speed_prediction = speed_no_stop_prediction(track, chunk, stop_threshold)
speed_prediction = speed_no_stop_prediction(track, chunks, stop_threshold)

confidence = max(classifier_prediction, speed_prediction)

Expand All @@ -155,6 +155,8 @@ def predict_no_stop(track, confidence_threshold, stop_threshold=10):
label = "speed not stopped"
elif classifier_prediction > confidence_threshold:
label = "classifier not stopped"
else:
label = "has stopped"

position = tuple(chunks[0, -1, 1, :2].astype(np.int))
prediction_tuple = (label, confidence, position, chunks[0], chunk_frames[0])
Expand Down Expand Up @@ -217,13 +219,21 @@ def log_predictions(predictions, no_stop_predictions):

if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Generates action predictions live given a video and a pre-trained classifier.')
description=('Generates action predictions live given a video and a pre-trained classifier. '
'Uses Tracker.tracker.video_generator which yields every track every frame, '
'from which it predicts the class of action using the pre-trained classifier. '
'To get a better prediction, it takes the latest 50, 30, 25, and 20 frames '
'as chunks and selects the likliest prediction among the five * n_classes. '
'It also predicts if a person has not stopped moving (e.g. if they are moving '
'through a self-checkout area without scanning anything) by checking if '
'a proportion of the latest identified actions for a track/person is moving.'))

parser.add_argument('--classifier', type=str,
help='Path to a .pkl file with a pre-trained action recognition classifier.')
parser.add_argument('--video', type=str,
help='Path to video file to predict actions for.')
parser.add_argument('--model-path', type=str, default='../openpose/models/',
help='The model path for the caffe implementation.')
help='The model path for OpenPose.')
parser.add_argument('--confidence-threshold', type=float, default=0.8,
help='Threshold for how confident the model should be in each prediction.')

Expand Down

0 comments on commit cbe4fd0

Please sign in to comment.