[Feature] lazy implementation of video predictions #1621

Merged
merged 24 commits into from
Nov 15, 2023
06cf002
Draft code for replacing list predictions with generator for video, mp4
philmarchenko Nov 7, 2023
5d6423c
Changed types, united save_mp4, rewrote save_gif
philmarchenko Nov 8, 2023
1835cf3
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 8, 2023
e040b47
Returned correct type inside iterables in Video(Detection)Prediction
philmarchenko Nov 8, 2023
4e5caac
Changed types to more correct ones
philmarchenko Nov 8, 2023
cade080
Provided similar changes to PoseEstimationVideoPrediction
philmarchenko Nov 8, 2023
0b6ec0e
Merge branch 'feature/video-save-show-fix' of github.com:hakuryuu96/s…
philmarchenko Nov 8, 2023
fe6091e
Merge branch 'master' into feature/video-save-show-fix
BloodAxe Nov 13, 2023
ccf2357
Added tests of save and show for video predictions
philmarchenko Nov 13, 2023
51c6d0f
Added example script for PE and changed detection example script
philmarchenko Nov 13, 2023
7f68cb7
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 13, 2023
4095b83
Removed test warning filter
philmarchenko Nov 13, 2023
d6a3390
Merge branch 'feature/video-save-show-fix' of github.com:hakuryuu96/s…
philmarchenko Nov 13, 2023
d11b121
Removed unused import
philmarchenko Nov 13, 2023
f3bce78
Removed duplicated import
philmarchenko Nov 13, 2023
fe7cb37
Removed show() from video tests, not available
philmarchenko Nov 13, 2023
ec705cd
Fixed pretrained weights flag in test video
philmarchenko Nov 13, 2023
1e25265
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 13, 2023
a6096e5
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 14, 2023
b1a3e1b
Merge branch 'master' into feature/video-save-show-fix
BloodAxe Nov 14, 2023
e81ec61
Replaced link with an arg for video path
philmarchenko Nov 15, 2023
a79a58f
Added a documentation line regarding the video predictions
philmarchenko Nov 15, 2023
6f2c467
Changed word in progress bar example
philmarchenko Nov 15, 2023
8f4cb05
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 15, 2023
20 changes: 12 additions & 8 deletions documentation/source/ModelPredictions.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Using Pretrained Models for Predictions
# Using Pretrained Models for Predictions

In this tutorial, we will demonstrate how to use the `model.predict()` method for object detection tasks.

Expand All @@ -10,7 +10,7 @@ The model used in this tutorial is [YOLO-NAS](YoloNASQuickstart.md), pre-trained

## Supported Media Formats

The `model.predict()` method is built to handle multiple data formats and types.
The `model.predict()` method is built to handle multiple data formats and types.
Here is the full list of what the `predict()` method can handle:

| Argument Semantics | Argument Type | Supported layout | Example | Notes |
Expand All @@ -25,13 +25,13 @@ Here is the full list of what `predict()` method can handle:
| 3-dimensional Torch Tensor | `torch.Tensor` | `[H, W, C]` or `[C, H, W]` | `predict(torch.zeros((480, 640, 3), dtype=torch.uint8))` | Tensor layout (HWC or CHW) is inferred from the number of input channels of the underlying model |
| 4-dimensional Torch Tensor | `torch.Tensor` | `[N, H, W, C]` or `[N, C, H, W]` | `predict(torch.zeros((4, 480, 640, 3), dtype=torch.uint8))` | Tensor layout (NHWC or NCHW) is inferred from the number of input channels of the underlying model |

**Important note** - When using batched input (4-dimensional `np.ndarray` or `torch.Tensor`) formats, **normalization and size preprocessing will be applied to these inputs**.
**Important note** - When using batched input (4-dimensional `np.ndarray` or `torch.Tensor`) formats, **normalization and size preprocessing will be applied to these inputs**.
This means that the input tensors **should not** be normalized beforehand.
Here is an example of **incorrect** usage of `model.predict()`:

```python
# Incorrect code example. Do not use it.
from super_gradients.training import dataloaders
from super_gradients.training import dataloaders
from super_gradients.common.object_names import Models
from super_gradients.training import models

Expand Down Expand Up @@ -139,6 +139,10 @@ You can also directly access a specific image prediction by referencing its inde
## Detect Objects in Animated GIFs and Videos
GIFs and videos are processed the same way, as both are treated as videos internally. You can use the same `model.predict()` method as before, but pass the path to a GIF or video file instead. The results can be saved as either a `.gif` or `.mp4`.

To mitigate Out-of-Memory (OOM) errors, the `model.predict()` method for video returns a generator object. This allows the video frames to be processed sequentially, minimizing memory usage. It's important to be aware that model inference in this mode will be slower since batching is not supported.

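Because a generator can be iterated only once, each `show()` or `save()` call consumes the predictions. Consequently, you need to invoke `model.predict()` again before each `show()` and `save()` call.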
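
The need for a fresh `model.predict()` call follows from ordinary Python generator behavior. The sketch below illustrates it in plain Python; `fake_predict` and `consume` are hypothetical stand-ins for a lazy video prediction and for `save()`/`show()`, not SuperGradients API.

```python
from typing import Iterator, List


def fake_predict(n_frames: int) -> Iterator[str]:
    """Hypothetical stand-in for a lazy video prediction: yields one result per frame."""
    for i in range(n_frames):
        yield f"frame-{i}"


def consume(predictions: Iterator[str]) -> List[str]:
    """Hypothetical stand-in for save()/show(): iterates over every frame once."""
    return list(predictions)


predictions = fake_predict(3)
first_pass = consume(predictions)   # consumes the generator
second_pass = consume(predictions)  # nothing is left to iterate

fresh_pass = consume(fake_predict(3))  # a new generator yields all frames again
```

Here `second_pass` comes back empty, while `fresh_pass` contains every frame, which mirrors why the example scripts call `model.predict()` before each output step.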

### Load an Animated GIF or Video File
Let's load an animated GIF or a video file and pass it to the `model.predict()` method:

Expand Down Expand Up @@ -170,7 +174,7 @@ media_predictions.save("output_video.mp4") # Save as .mp4
The number of Frames Per Second (FPS) at which the model processes the gif/video can be seen directly next to the loading bar when running `model.predict('my_video.mp4')`.

In the following example, the model processes 39.49 it/s, i.e. 39.49 frames per second:
`Predicting Video: 100%|███████████████████████| 306/306 [00:07<00:00, 39.49it/s]`
`Processing Video: 100%|███████████████████████| 306/306 [00:07<00:00, 39.49it/s]`

Note that the video/gif will be saved with original FPS (i.e. `media_predictions.fps`).

Expand Down Expand Up @@ -237,13 +241,13 @@ predictions = model.predict(image, skip_image_resizing=True)
The following images illustrate the difference in detection results with and without resizing.

#### Original Image
![Original Image](images/detection_example_beach_raw_image.jpeg)
![Original Image](images/detection_example_beach_raw_image.jpeg)
*This is the raw image before any processing.*

#### Image Processed with Standard Resizing (640x640)
![Resized Image](images/detection_example_beach_resized_predictions.jpg)
![Resized Image](images/detection_example_beach_resized_predictions.jpg)
*This image shows the detection results after resizing the image to the model's trained size of 640x640.*

#### Image Processed in Original Size
![Original Size Image](images/detection_example_beach_raw_image_prediction.jpg)
![Original Size Image](images/detection_example_beach_raw_image_prediction.jpg)
*Here, the image is processed in its original size, demonstrating how the model performs without resizing. Notice the differences in object detection and details compared to the resized version.*
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
f.write(response.content)

predictions = model.predict(video_path)
predictions.show()
predictions.save("pose_elephant_flip_prediction.mp4")

predictions = model.predict(video_path)
predictions.save("pose_elephant_flip_prediction.gif") # Can also be saved as a gif.

predictions = model.predict(video_path)
predictions.show()
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
from super_gradients.training import models

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--path_to_video", type=str)

if __name__ == "__main__":
    args = parser.parse_args()

    # Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
    model = models.get("yolo_nas_pose_l", pretrained_weights="coco_pose")

    # We want to use cuda if available to speed up inference.
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")

    # predict() returns a lazy result, so it is called again before each save()/show().
    predictions = model.predict(args.path_to_video)
    predictions.save(f"{args.path_to_video.split('/')[-1]}_prediction.mp4")

    predictions = model.predict(args.path_to_video)
    predictions.save(f"{args.path_to_video.split('/')[-1]}_prediction.gif")  # Can also be saved as a gif.

    predictions = model.predict(args.path_to_video)
    predictions.show()
13 changes: 6 additions & 7 deletions src/super_gradients/training/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
ClassificationPrediction,
)
from super_gradients.training.utils.utils import generate_batch, infer_model_device, resolve_torch_device
from super_gradients.training.utils.media.video import load_video, includes_video_extension
from super_gradients.training.utils.media.video import includes_video_extension, lazy_load_video
from super_gradients.training.utils.media.image import ImageSource, check_image_typing
from super_gradients.training.utils.media.stream import WebcamStreaming
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
Expand Down Expand Up @@ -134,9 +134,10 @@ def predict_video(self, video_path: str, batch_size: Optional[int] = 32) -> Vide
:param batch_size: The size of each batch.
:return: Results of the prediction.
"""
video_frames, fps = load_video(file_path=video_path)
video_frames, fps, num_frames = lazy_load_video(file_path=video_path)
result_generator = self._generate_prediction_result(images=video_frames, batch_size=batch_size)
return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=len(video_frames))
return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=num_frames)
# return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=len(video_frames))

def predict_webcam(self) -> None:
"""Predict using webcam"""
Expand Down Expand Up @@ -335,8 +336,7 @@ def _combine_image_prediction_to_images(
def _combine_image_prediction_to_video(
self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None
) -> VideoDetectionPrediction:
images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")]
return VideoDetectionPrediction(_images_prediction_lst=images_predictions, fps=fps)
return VideoDetectionPrediction(_images_prediction_gen=images_predictions, fps=fps, n_frames=n_images)


class PoseEstimationPipeline(Pipeline):
Expand Down Expand Up @@ -419,8 +419,7 @@ def _combine_image_prediction_to_images(
def _combine_image_prediction_to_video(
self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None
) -> VideoPoseEstimationPrediction:
images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")]
return VideoPoseEstimationPrediction(_images_prediction_lst=images_predictions, fps=fps)
return VideoPoseEstimationPrediction(_images_prediction_gen=images_predictions, fps=fps, n_frames=n_images)


class ClassificationPipeline(Pipeline):
Expand Down
93 changes: 64 additions & 29 deletions src/super_gradients/training/utils/media/video.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Iterable, Iterator
import cv2
import PIL

Expand Down Expand Up @@ -30,6 +30,23 @@ def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[n
return frames, fps


def lazy_load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[Iterator[np.ndarray], int, int]:
    """Open a video file and return a generator which yields frames.

    :param file_path: Path to the video file.
    :param max_frames: Optional, maximum number of frames to extract.
    :return:
        - Generator yielding frames representing the video, each in (H, W, C), RGB.
        - Frames per Second (FPS).
        - Number of frames in the video.
    """
    cap = _open_video(file_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = _lazy_extract_frames(cap, max_frames)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    return frames, fps, num_frames


def _open_video(file_path: str) -> cv2.VideoCapture:
"""Open a video file.

Expand Down Expand Up @@ -61,6 +78,27 @@ def _extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) ->
return frames


def _lazy_extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) -> Iterator[np.ndarray]:
    """Lazy implementation of frames extraction from an opened video capture object.
    NOTE: Releases the capture object.

    :param cap: Opened video capture object.
    :param max_frames: Optional maximum number of frames to extract.
    :return: Generator yielding frames representing the video, each in (H, W, C), RGB.
    """
    frames_counter = 0

    while frames_counter != max_frames:
        frame_read_success, frame = cap.read()
        if not frame_read_success:
            break

        frames_counter += 1
        yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    cap.release()
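
In the generator above, `cap.release()` is reached only when the loop ends normally. One possible variation, sketched here under assumptions, wraps the loop in `try`/`finally` so the capture is also released when the consumer stops iterating early. `FakeCapture` and `lazy_frames` are hypothetical names used for illustration; the fake class only mimics the `read()` and `release()` calls and is not part of OpenCV or this PR.

```python
from typing import Iterator, Optional, Tuple


class FakeCapture:
    """Hypothetical stand-in for cv2.VideoCapture exposing only read() and release()."""

    def __init__(self, n_frames: int) -> None:
        self._remaining = n_frames
        self.released = False

    def read(self) -> Tuple[bool, Optional[int]]:
        if self._remaining == 0:
            return False, None
        self._remaining -= 1
        return True, self._remaining

    def release(self) -> None:
        self.released = True


def lazy_frames(cap: FakeCapture, max_frames: Optional[int] = None) -> Iterator[int]:
    """Yield frames one at a time; release the capture even if iteration stops early."""
    count = 0
    try:
        while count != max_frames:
            ok, frame = cap.read()
            if not ok:
                break
            count += 1
            yield frame
    finally:
        cap.release()


cap = FakeCapture(n_frames=5)
frames = lazy_frames(cap)
first = next(frames)  # only one frame has been read so far
frames.close()        # abandoning the generator triggers the finally block
```

After `frames.close()`, `cap.released` is `True` even though four frames were never read.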


def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
"""Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.

Expand All @@ -78,64 +116,61 @@ def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
save_mp4(output_path, frames, fps)


def save_gif(output_path: str, frames: List[np.ndarray], fps: int) -> None:
"""Save a video locally in .gif format.
def save_gif(output_path: str, frames: Iterable[np.ndarray], fps: int) -> None:
"""Save a video locally in .gif format. Safe for generator of frames object.

:param output_path: Where the video will be saved
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:param fps: Frames per second
"""
frame_iter_obj = iter(frames)
pil_frames_iter_obj = map(PIL.Image.fromarray, frame_iter_obj)

frames_pil = [PIL.Image.fromarray(frame) for frame in frames]
first_frame = next(pil_frames_iter_obj)

frames_pil[0].save(output_path, save_all=True, append_images=frames_pil[1:], duration=int(1000 / fps), loop=0)
first_frame.save(output_path, save_all=True, append_images=pil_frames_iter_obj, duration=int(1000 / fps), loop=0)


def save_mp4(output_path: str, frames: List[np.ndarray], fps: int) -> None:
"""Save a video locally in .mp4 format.
def save_mp4(output_path: str, frames: Iterable[np.ndarray], fps: int) -> None:
"""Save a video locally in .mp4 format. Safe for generator of frames object.

:param output_path: Where the video will be saved
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:param fps: Frames per second
"""
video_height, video_width = _validate_frames(frames)

video_writer = cv2.VideoWriter(
output_path,
cv2.VideoWriter_fourcc(*"mp4v"),
fps,
(video_width, video_height),
)
video_height, video_width, video_writer = None, None, None

for frame in frames:
if video_height is None:
video_height, video_width = frame.shape[:2]
video_writer = cv2.VideoWriter(
output_path,
cv2.VideoWriter_fourcc(*"mp4v"),
fps,
(video_width, video_height),
)
_validate_frame(frame, video_height, video_width)
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

video_writer.release()
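
The new `save_mp4` creates the writer lazily from the first frame, but calls `video_writer.release()` unconditionally, so an empty iterable would reach that line with `video_writer` still `None`. The following is a minimal sketch of the same lazy-initialization pattern with that case guarded. `ListWriter`, `write_lazily`, and the nested-list `Frame` type are hypothetical stand-ins for illustration, not OpenCV or SuperGradients API.

```python
from typing import Callable, Iterable, List, Optional, Tuple

Frame = List[List[int]]  # hypothetical frame type: rows of pixel values


class ListWriter:
    """Hypothetical stand-in for cv2.VideoWriter that records what it receives."""

    def __init__(self, size: Tuple[int, int]) -> None:
        self.size = size  # (width, height)
        self.frames: List[Frame] = []
        self.released = False

    def write(self, frame: Frame) -> None:
        self.frames.append(frame)

    def release(self) -> None:
        self.released = True


def write_lazily(
    frames: Iterable[Frame],
    make_writer: Callable[[Tuple[int, int]], ListWriter],
) -> Optional[ListWriter]:
    """Create the writer from the first frame's size, then stream the remaining frames."""
    writer: Optional[ListWriter] = None
    try:
        for frame in frames:
            height, width = len(frame), len(frame[0])
            if writer is None:
                writer = make_writer((width, height))
            elif (width, height) != writer.size:
                raise RuntimeError(
                    f"Frame is {height}x{width}, expected {writer.size[1]}x{writer.size[0]}."
                )
            writer.write(frame)
    finally:
        if writer is not None:  # an empty input never creates a writer
            writer.release()
    return writer


lazy_input = ([[0] * 4 for _ in range(2)] for _ in range(3))  # three 2x4 frames
writer = write_lazily(lazy_input, ListWriter)
empty_result = write_lazily(iter([]), ListWriter)
```

Only one frame is held in memory at a time, and an empty input returns `None` instead of failing on `release()`.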


def _validate_frames(frames: List[np.ndarray]) -> Tuple[float, float]:
"""Validate the frames to make sure that every frame has the same size and includes the channel dimension. (i.e. (H, W, C))
def _validate_frame(frame: np.ndarray, control_height: int, control_width: int) -> None:
"""Validate the frame to make sure it has the correct size and includes the channel dimension. (i.e. (H, W, C))

:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:return: (Height, Weight) of the video.
:param frame: Single frame from the video, in (H, W, C), RGB.
"""
min_height = min(frame.shape[0] for frame in frames)
max_height = max(frame.shape[0] for frame in frames)

min_width = min(frame.shape[1] for frame in frames)
max_width = max(frame.shape[1] for frame in frames)
height, width = frame.shape[:2]

if (min_height, min_width) != (max_height, max_width):
if (height, width) != (control_height, control_width):
raise RuntimeError(
f"Your video is made of frames that have (height, width) going from ({min_height}, {min_width}) to ({max_height}, {max_width}).\n"
f"Current frame has resolution {height}x{width} but {control_height}x{control_width} is expected!\n"
f"Please make sure that all the frames have the same shape."
)

if set(frame.ndim for frame in frames) != {3} or set(frame.shape[-1] for frame in frames) != {3}:
if frame.ndim != 3:
raise RuntimeError("Your frames must include 3 channels.")

return max_height, max_width


def show_video_from_disk(video_path: str, window_name: str = "Prediction"):
"""Display a video from disk using OpenCV.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass
from typing import List
from typing import List, Iterator

import numpy as np

Expand All @@ -9,6 +9,8 @@
from super_gradients.training.utils.media.video import show_video_from_frames, save_video
from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization

from tqdm import tqdm


@dataclass
class ImagePoseEstimationPrediction(ImagePrediction):
Expand Down Expand Up @@ -210,8 +212,9 @@ class VideoPoseEstimationPrediction(VideoPredictions):
:att fps: Frames per second of the video
"""

_images_prediction_lst: List[ImagePoseEstimationPrediction]
_images_prediction_gen: Iterator[ImagePoseEstimationPrediction]
fps: int
n_frames: int

def draw(
self,
Expand All @@ -221,7 +224,7 @@ def draw(
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> List[np.ndarray]:
) -> Iterator[np.ndarray]:
"""Draw the predicted bboxes on the images.

:param output_folder: Folder path, where the images will be saved.
Expand All @@ -236,20 +239,18 @@ def draw(
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.

:return: List of images with predicted bboxes. Note that this does not modify the original image.
:return: Iterator of images with predicted bboxes. Note that this does not modify the original image.
"""
frames_with_bbox = [
result.draw(

for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
yield result.draw(
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
)
for result in self._images_prediction_lst
]
return frames_with_bbox

def show(
self,
Expand Down