[Feature] lazy implementation of video predictions (#1621)
* Draft code for replacing list predictions with generator for video, mp4

* Changed types, united save_mp4, rewrote save_gif

* Returned correct type inside iterables in Video(Detection)Prediction

* Changed types to more correct ones

* Provided similar changes to PoseEstimationVideoPrediction

* Added tests of save and show for video predictions

* Added example script for PE and changed detection example script

* Removed test warning filter

* Removed unused import

* Removed duplicated import

* Removed show() from video tests, not available

* Fixed pretrained weights flag in test video

* Replaced link with an arg for video path

* Added a documentation line regarding the video predictions

* Changed word in progress bar example

---------

Co-authored-by: Eugene Khvedchenya <[email protected]>
philmarchenko and BloodAxe authored Nov 15, 2023
1 parent 83eded4 commit d925039
Showing 8 changed files with 172 additions and 68 deletions.
20 changes: 12 additions & 8 deletions documentation/source/ModelPredictions.md
@@ -1,4 +1,4 @@
# Using Pretrained Models for Predictions

In this tutorial, we will demonstrate how to use the `model.predict()` method for object detection tasks.

@@ -10,7 +10,7 @@ The model used in this tutorial is [YOLO-NAS](YoloNASQuickstart.md), pre-trained

## Supported Media Formats

The `model.predict()` method is built to handle multiple data formats and types.
Here is the full list of what the `predict()` method can handle:

| Argument Semantics | Argument Type | Supported layout | Example | Notes |
@@ -25,13 +25,13 @@ Here is the full list of what `predict()` method can handle:
| 3-dimensional Torch Tensor | `torch.Tensor` | `[H, W, C]` or `[C, H, W]` | `predict(torch.zeros((480, 640, 3), dtype=torch.uint8))` | Tensor layout (HWC or CHW) is inferred w.r.t. the number of input channels of the underlying model |
| 4-dimensional Torch Tensor | `torch.Tensor` | `[N, H, W, C]` or `[N, C, H, W]` | `predict(torch.zeros((4, 480, 640, 3), dtype=torch.uint8))` | Tensor layout (NHWC or NCHW) is inferred w.r.t. the number of input channels of the underlying model |

**Important note** - When using batched input (4-dimensional `np.ndarray` or `torch.Tensor`) formats, **normalization and size preprocessing will be applied to these inputs**.
This means that the input tensors **should not** be normalized beforehand.
Here is an example of **incorrect** usage of `model.predict()`:

```python
# Incorrect code example. Do not use it.
from super_gradients.training import dataloaders
from super_gradients.common.object_names import Models
from super_gradients.training import models
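
# The remainder of this snippet is collapsed in the diff view. What follows is
# a hedged reconstruction of the anti-pattern being warned about (the loader
# choice is an assumption, not the document's original code):
model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
images, *_ = next(iter(dataloaders.coco2017_val()))  # batch tensors already normalized by the loader
model.predict(images)  # WRONG: predict() applies normalization a second time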

@@ -139,6 +139,10 @@ You can also directly access a specific image prediction by referencing its index
## Detect Objects in Animated GIFs and Videos
The processing of GIFs and videos is similar, as both are treated as videos internally. You can use the same `model.predict()` method as before, but pass the path to a GIF or video file instead. The results can be saved as either a `.gif` or an `.mp4`.

To mitigate Out-of-Memory (OOM) errors, the `model.predict()` method for video returns a generator object. This allows the video frames to be processed sequentially, minimizing memory usage. It's important to be aware that model inference in this mode will be slower since batching is not supported.

Consequently, you need to invoke `model.predict()` anew before each `show()` and `save()` call, as the example below illustrates.
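
For example (a minimal sketch; the model setup mirrors the detection examples above, and the file names are placeholders):

```python
from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

# predict() returns a lazy, single-use generator over the video frames.
media_predictions = model.predict("my_video.mp4")
media_predictions.save("output_video.mp4")

# The previous generator is now exhausted, so predict() must be called again.
media_predictions = model.predict("my_video.mp4")
media_predictions.show()
```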

### Load an Animated GIF or Video File
Let's load an animated GIF or a video file and pass it to the `model.predict()` method:

@@ -170,7 +174,7 @@ media_predictions.save("output_video.mp4") # Save as .mp4
The number of Frames Per Second (FPS) at which the model processes the gif/video is shown next to the progress bar when running `model.predict('my_video.mp4')`.

In the following example, the processing rate is 39.49 it/s (i.e., FPS):
`Processing Video: 100%|███████████████████████| 306/306 [00:07<00:00, 39.49it/s]`

Note that the video/gif will be saved with original FPS (i.e. `media_predictions.fps`).
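
In code, the two rates are easy to tell apart (an illustrative sketch; the attribute follows the prediction object shown above):

```python
media_predictions = model.predict("my_video.mp4")
print(media_predictions.fps)                # FPS of the source video, used when saving
media_predictions.save("output_video.mp4")  # written at the original FPS, not at 39.49 it/s
```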

@@ -237,13 +241,13 @@ predictions = model.predict(image, skip_image_resizing=True)
The following images illustrate the difference in detection results with and without resizing.

#### Original Image
![Original Image](images/detection_example_beach_raw_image.jpeg)
*This is the raw image before any processing.*

#### Image Processed with Standard Resizing (640x640)
![Resized Image](images/detection_example_beach_resized_predictions.jpg)
*This image shows the detection results after resizing the image to the model's trained size of 640x640.*

#### Image Processed in Original Size
![Original Size Image](images/detection_example_beach_raw_image_prediction.jpg)
*Here, the image is processed in its original size, demonstrating how the model performs without resizing. Notice the differences in object detection and details compared to the resized version.*
@@ -17,6 +17,10 @@
f.write(response.content)

predictions = model.predict(video_path)
predictions.save("pose_elephant_flip_prediction.mp4")

predictions = model.predict(video_path)
predictions.save("pose_elephant_flip_prediction.gif") # Can also be saved as a gif.

predictions = model.predict(video_path)
predictions.show()
@@ -0,0 +1,25 @@
import torch
from super_gradients.training import models

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--path_to_video", type=str)

if __name__ == "__main__":
    args = parser.parse_args()

    # Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
    model = models.get("yolo_nas_pose_l", pretrained_weights="coco_pose")

    # We want to use cuda if available to speed up inference.
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")

    predictions = model.predict(args.path_to_video)
    predictions.save(f"{args.path_to_video.split('/')[-1]}_prediction.mp4")

    predictions = model.predict(args.path_to_video)
    predictions.save(f"{args.path_to_video.split('/')[-1]}_prediction.gif")  # Can also be saved as a gif.

    predictions = model.predict(args.path_to_video)
    predictions.show()
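
Note the three separate `predict()` calls in this script: with the lazy video pipeline, each returned prediction object wraps a single-use frame generator, so every `save()` or `show()` needs a fresh `predict()` (see the documentation note above).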
13 changes: 6 additions & 7 deletions src/super_gradients/training/pipelines/pipelines.py
@@ -25,7 +25,7 @@
    ClassificationPrediction,
)
from super_gradients.training.utils.utils import generate_batch, infer_model_device, resolve_torch_device
from super_gradients.training.utils.media.video import includes_video_extension, lazy_load_video
from super_gradients.training.utils.media.image import ImageSource, check_image_typing
from super_gradients.training.utils.media.stream import WebcamStreaming
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
@@ -134,9 +134,10 @@ def predict_video(self, video_path: str, batch_size: Optional[int] = 32) -> Vide
        :param batch_size: The size of each batch.
        :return: Results of the prediction.
        """
        video_frames, fps, num_frames = lazy_load_video(file_path=video_path)
        result_generator = self._generate_prediction_result(images=video_frames, batch_size=batch_size)
        return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=num_frames)

    def predict_webcam(self) -> None:
        """Predict using webcam"""
@@ -335,8 +336,7 @@ def _combine_image_prediction_to_images(
    def _combine_image_prediction_to_video(
        self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None
    ) -> VideoDetectionPrediction:
        return VideoDetectionPrediction(_images_prediction_gen=images_predictions, fps=fps, n_frames=n_images)


class PoseEstimationPipeline(Pipeline):
@@ -419,8 +419,7 @@ def _combine_image_prediction_to_images(
    def _combine_image_prediction_to_video(
        self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None
    ) -> VideoPoseEstimationPrediction:
        return VideoPoseEstimationPrediction(_images_prediction_gen=images_predictions, fps=fps, n_frames=n_images)


class ClassificationPipeline(Pipeline):
93 changes: 64 additions & 29 deletions src/super_gradients/training/utils/media/video.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Iterable, Iterator
import cv2
import PIL

@@ -30,6 +30,23 @@ def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[n
    return frames, fps


def lazy_load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[Iterator[np.ndarray], int, int]:
    """Open a video file and return a generator which yields frames.
    :param file_path: Path to the video file.
    :param max_frames: Optional, maximum number of frames to extract.
    :return:
        - Generator yielding frames representing the video, each in (H, W, C), RGB.
        - Frames per Second (FPS).
        - Number of frames in the video.
    """
    cap = _open_video(file_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = _lazy_extract_frames(cap, max_frames)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    return frames, fps, num_frames
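
To illustrate how the new helpers compose (an editorial sketch, not part of the diff; assumes a local `my_video.mp4`):

```python
from super_gradients.training.utils.media.video import lazy_load_video, save_mp4

frames, fps, num_frames = lazy_load_video("my_video.mp4")
print(f"{num_frames} frames at {fps} FPS")

# `frames` is a single-use generator, so the video is re-encoded
# frame by frame without ever being fully held in memory.
save_mp4("copy_of_my_video.mp4", frames, int(fps))
```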


def _open_video(file_path: str) -> cv2.VideoCapture:
"""Open a video file.
Expand Down Expand Up @@ -61,6 +78,27 @@ def _extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) ->
return frames


def _lazy_extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) -> Iterator[np.ndarray]:
    """Lazy implementation of frame extraction from an opened video capture object.
    NOTE: Releases the capture object once the generator is exhausted.
    :param cap: Opened video capture object.
    :param max_frames: Optional maximum number of frames to extract.
    :return: Generator yielding frames representing the video, each in (H, W, C), RGB.
    """
    frames_counter = 0

    while frames_counter != max_frames:
        frame_read_success, frame = cap.read()
        if not frame_read_success:
            break

        frames_counter += 1
        yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    cap.release()


def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
"""Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.
Expand All @@ -78,64 +116,61 @@ def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
save_mp4(output_path, frames, fps)


def save_gif(output_path: str, frames: Iterable[np.ndarray], fps: int) -> None:
    """Save a video locally in .gif format. Safe to use with a generator of frames.
    :param output_path: Where the video will be saved
    :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps: Frames per second
    """
    frame_iter_obj = iter(frames)
    pil_frames_iter_obj = map(PIL.Image.fromarray, frame_iter_obj)

    # Write the first frame eagerly; PIL streams the rest from the iterator.
    first_frame = next(pil_frames_iter_obj)

    first_frame.save(output_path, save_all=True, append_images=pil_frames_iter_obj, duration=int(1000 / fps), loop=0)


def save_mp4(output_path: str, frames: Iterable[np.ndarray], fps: int) -> None:
    """Save a video locally in .mp4 format. Safe to use with a generator of frames.
    :param output_path: Where the video will be saved
    :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps: Frames per second
    """
    video_height, video_width, video_writer = None, None, None

    for frame in frames:
        if video_height is None:
            # Lazily initialize the writer from the first frame's dimensions.
            video_height, video_width = frame.shape[:2]
            video_writer = cv2.VideoWriter(
                output_path,
                cv2.VideoWriter_fourcc(*"mp4v"),
                fps,
                (video_width, video_height),
            )
        _validate_frame(frame, video_height, video_width)
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    video_writer.release()


def _validate_frame(frame: np.ndarray, control_height: int, control_width: int) -> None:
    """Validate a frame to make sure it has the expected size and includes the channel dimension. (i.e. (H, W, C))
    :param frame: Single frame from the video, in (H, W, C), RGB.
    """
    height, width = frame.shape[:2]

    if (height, width) != (control_height, control_width):
        raise RuntimeError(
            f"Current frame has resolution {height}x{width} but {control_height}x{control_width} is expected! "
            f"Please make sure that all the frames have the same shape."
        )

    if frame.ndim != 3:
        raise RuntimeError("Your frames must include 3 channels.")


def show_video_from_disk(video_path: str, window_name: str = "Prediction"):
"""Display a video from disk using OpenCV.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass
from typing import List, Iterator

import numpy as np

@@ -9,6 +9,8 @@
from super_gradients.training.utils.media.video import show_video_from_frames, save_video
from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization

from tqdm import tqdm


@dataclass
class ImagePoseEstimationPrediction(ImagePrediction):
@@ -210,8 +212,9 @@ class VideoPoseEstimationPrediction(VideoPredictions):
    :att fps: Frames per second of the video
    """

    _images_prediction_gen: Iterator[ImagePoseEstimationPrediction]
    fps: int
    n_frames: int

    def draw(
        self,
@@ -221,7 +224,7 @@ def draw(
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
    ) -> Iterator[np.ndarray]:
"""Draw the predicted bboxes on the images.
        :param output_folder: Folder path, where the images will be saved.
@@ -236,20 +239,18 @@
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness: Thickness of bounding boxes.
        :return: Iterator of images with predicted bboxes. Note that this does not modify the original image.
"""
        for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
            yield result.draw(
                edge_colors=edge_colors,
                joint_thickness=joint_thickness,
                keypoint_colors=keypoint_colors,
                keypoint_radius=keypoint_radius,
                box_thickness=box_thickness,
                show_confidence=show_confidence,
            )

    def show(
        self,