From 45835ac86f0bd8c8b130d8aa5065ac1df226763b Mon Sep 17 00:00:00 2001 From: Peng Lu Date: Mon, 10 Jul 2023 15:56:28 +0800 Subject: [PATCH] [Feature] Add Application 'Just dance' (#2528) --- projects/README.md | 6 + projects/just_dance/README.md | 34 + projects/just_dance/app.py | 109 +++ projects/just_dance/calculate_similarity.py | 105 +++ projects/just_dance/configs/_base_ | 1 + .../configs/rtmdet-nano_one-person.py | 3 + projects/just_dance/just_dance_demo.ipynb | 712 ++++++++++++++++++ projects/just_dance/process_video.py | 229 ++++++ projects/just_dance/utils.py | 106 +++ 9 files changed, 1305 insertions(+) create mode 100644 projects/just_dance/README.md create mode 100644 projects/just_dance/app.py create mode 100644 projects/just_dance/calculate_similarity.py create mode 120000 projects/just_dance/configs/_base_ create mode 100644 projects/just_dance/configs/rtmdet-nano_one-person.py create mode 100644 projects/just_dance/just_dance_demo.ipynb create mode 100644 projects/just_dance/process_video.py create mode 100644 projects/just_dance/utils.py diff --git a/projects/README.md b/projects/README.md index eacc217388..1089af8194 100644 --- a/projects/README.md +++ b/projects/README.md @@ -54,4 +54,10 @@ We also provide some documentation listed below to help you get started:
+- **[💃Just-Dance](./just_dance)**: A dance scoring system for comparing dance performances in videos + +
+ +

+ - **What's next? Join the rank of *MMPose contributors* by creating a new project**! diff --git a/projects/just_dance/README.md b/projects/just_dance/README.md new file mode 100644 index 0000000000..1255996766 --- /dev/null +++ b/projects/just_dance/README.md @@ -0,0 +1,34 @@ +# Just Dance - A Simple Implementation + +This project presents a dance scoring system based on RTMPose. Users can compare the similarity between two dancers in different videos: one referred to as the "teacher video" and the other as the "student video." + +Here is an example of the output dance comparison: + +![output](https://github.com/open-mmlab/mmpose/assets/26127467/56d5c4d1-55d8-4222-b481-2418cc29a8d4) + +## Usage + +### Jupyter Notebook + +We provide a Jupyter Notebook [`just_dance_demo.ipynb`](./just_dance_demo.ipynb) that contains the complete process of dance comparison. It includes steps such as video FPS adjustment, pose estimation, snippet alignment, scoring, and the generation of the merged video. + +### CLI tool + +Users can simply run the following command to generate the comparison video: + +```shell +python process_video ${TEACHER_VIDEO} ${STUDENT_VIDEO} +``` + +### Gradio + +Users can also utilize Gradio to build an application using this system. We provide the script [`app.py`](./app.py). This application supports webcam input in addition to existing videos. To build this application, please follow these two steps: + +1. Install Gradio + ```shell + pip install gradio + ``` +2. Run the script [`app.py`](./app.py) + ```shell + python app.py + ``` diff --git a/projects/just_dance/app.py b/projects/just_dance/app.py new file mode 100644 index 0000000000..9b40c64fdd --- /dev/null +++ b/projects/just_dance/app.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import sys +from functools import partial +from typing import Optional + +project_path = os.path.join(os.path.dirname(os.path.abspath(__file__))) +mmpose_path = project_path.split('/projects', 1)[0] + +os.system('python -m pip install Openmim') +os.system('python -m mim install "mmcv>=2.0.0"') +os.system('python -m mim install mmengine') +os.system('python -m mim install "mmdet>=3.0.0"') +os.system(f'python -m mim install -e {mmpose_path}') + +os.environ['PATH'] = f"{os.environ['PATH']}:{project_path}" +os.environ[ + 'PYTHONPATH'] = f"{os.environ.get('PYTHONPATH', '.')}:{project_path}" +sys.path.append(project_path) + +import gradio as gr # noqa +from mmengine.utils import mkdir_or_exist # noqa +from process_video import VideoProcessor # noqa + + +def process_video( + teacher_video: Optional[str] = None, + student_video: Optional[str] = None, +): + print(teacher_video) + print(student_video) + + video_processor = VideoProcessor() + if student_video is None and teacher_video is not None: + # Pre-process the teacher video when users record the student video + # using a webcam. This allows users to view the teacher video and + # follow the dance moves while recording the student video. 
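+        # Note (descriptive comment): get_keypoints_from_video caches the
+        # extracted keypoints next to the video file, so the later grading
+        # call can reuse them instead of running pose estimation again.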
+ _ = video_processor.get_keypoints_from_video(teacher_video) + return teacher_video + elif teacher_video is None and student_video is not None: + _ = video_processor.get_keypoints_from_video(student_video) + return student_video + elif teacher_video is None and student_video is None: + return None + + return video_processor.run(teacher_video, student_video) + + +# download video resources +mkdir_or_exist(os.path.join(project_path, 'resources')) +os.system( + f'wget -O {project_path}/resources/tom.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tom.mp4' # noqa +) +os.system( + f'wget -O {project_path}/resources/idol_producer.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/idol_producer.mp4' # noqa +) +os.system( + f'wget -O {project_path}/resources/tsinghua_30fps.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tsinghua_30fps.mp4' # noqa +) + +with gr.Blocks() as demo: + with gr.Tab('Upload-Video'): + with gr.Row(): + with gr.Column(): + gr.Markdown('Student Video') + student_video = gr.Video(type='mp4') + gr.Examples([ + os.path.join(project_path, 'resources/tom.mp4'), + os.path.join(project_path, 'resources/tsinghua_30fps.mp4') + ], student_video) + with gr.Column(): + gr.Markdown('Teacher Video') + teacher_video = gr.Video(type='mp4') + gr.Examples([ + os.path.join(project_path, 'resources/idol_producer.mp4') + ], teacher_video) + + button = gr.Button('Grading', variant='primary') + gr.Markdown('## Display') + out_video = gr.Video() + + button.click( + partial(process_video), [teacher_video, student_video], out_video) + + with gr.Tab('Webcam-Video'): + with gr.Row(): + with gr.Column(): + gr.Markdown('Student Video') + student_video = gr.Video(source='webcam', type='mp4') + with gr.Column(): + gr.Markdown('Teacher Video') + teacher_video = gr.Video(type='mp4') + gr.Examples([ + os.path.join(project_path, 'resources/idol_producer.mp4') + ], teacher_video) + button_upload = gr.Button('Upload', variant='primary') + + button = gr.Button('Grading', variant='primary') + gr.Markdown('## Display') + out_video = gr.Video() + + button_upload.click( + partial(process_video), [teacher_video, student_video], out_video) + button.click( + partial(process_video), [teacher_video, student_video], out_video) + +gr.close_all() +demo.queue() +demo.launch() diff --git a/projects/just_dance/calculate_similarity.py b/projects/just_dance/calculate_similarity.py new file mode 100644 index 0000000000..0465dbffaa --- /dev/null +++ b/projects/just_dance/calculate_similarity.py @@ -0,0 +1,105 @@ +import numpy as np +import torch + +flip_indices = np.array( + [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]) +valid_indices = np.array([0] + list(range(5, 17))) + + +@torch.no_grad() +def _calculate_similarity(tch_kpts: np.ndarray, stu_kpts: np.ndarray): + + stu_kpts = torch.from_numpy(stu_kpts[:, None, valid_indices]) + tch_kpts = torch.from_numpy(tch_kpts[None, :, valid_indices]) + stu_kpts = stu_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1], + stu_kpts.shape[2], 3) + tch_kpts = tch_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1], + stu_kpts.shape[2], 3) + + matrix = torch.stack((stu_kpts, tch_kpts), dim=4) + if torch.cuda.is_available(): + matrix = matrix.cuda() + mask = torch.logical_and(matrix[:, :, :, 2, 0] > 0.3, + matrix[:, :, :, 2, 1] > 0.3) + matrix[~mask] = 0.0 + + matrix_ = matrix.clone() + matrix_[matrix == 0] = 256 + x_min = matrix_.narrow(3, 0, 1).min(dim=2).values + y_min = matrix_.narrow(3, 1, 1).min(dim=2).values + matrix_ = matrix.clone() 
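+    # In this clone the masked (low-score) keypoints stay at 0, so they do not
+    # inflate the max below; together with the min computed above (where masked
+    # entries were filled with 256) this yields a per-pose bounding box that is
+    # then used to normalize the keypoint coordinates to [0, 1].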
+ # matrix_[matrix == 0] = 0 + x_max = matrix_.narrow(3, 0, 1).max(dim=2).values + y_max = matrix_.narrow(3, 1, 1).max(dim=2).values + + matrix_ = matrix.clone() + matrix_[:, :, :, 0] = (matrix_[:, :, :, 0] - x_min) / ( + x_max - x_min + 1e-4) + matrix_[:, :, :, 1] = (matrix_[:, :, :, 1] - y_min) / ( + y_max - y_min + 1e-4) + matrix_[:, :, :, 2] = (matrix_[:, :, :, 2] > 0.3).float() + xy_dist = matrix_[..., :2, 0] - matrix_[..., :2, 1] + score = matrix_[..., 2, 0] * matrix_[..., 2, 1] + + similarity = (torch.exp(-50 * xy_dist.pow(2).sum(dim=-1)) * + score).sum(dim=-1) / ( + score.sum(dim=-1) + 1e-6) + num_visible_kpts = score.sum(dim=-1) + similarity = similarity * torch.log( + (1 + (num_visible_kpts - 1) * 10).clamp(min=1)) / np.log(161) + + similarity[similarity.isnan()] = 0 + + return similarity + + +@torch.no_grad() +def calculate_similarity(tch_kpts: np.ndarray, stu_kpts: np.ndarray): + assert tch_kpts.shape[1] == 17 + assert tch_kpts.shape[2] == 3 + assert stu_kpts.shape[1] == 17 + assert stu_kpts.shape[2] == 3 + + similarity1 = _calculate_similarity(tch_kpts, stu_kpts) + + stu_kpts_flip = stu_kpts[:, flip_indices] + stu_kpts_flip[..., 0] = 191.5 - stu_kpts_flip[..., 0] + similarity2 = _calculate_similarity(tch_kpts, stu_kpts_flip) + + similarity = torch.stack((similarity1, similarity2)).max(dim=0).values + + return similarity + + +@torch.no_grad() +def select_piece_from_similarity(similarity): + m, n = similarity.size() + row_indices = torch.arange(m).view(-1, 1).expand(m, n).to(similarity) + col_indices = torch.arange(n).view(1, -1).expand(m, n).to(similarity) + diagonal_indices = similarity.size(0) - 1 - row_indices + col_indices + unique_diagonal_indices, inverse_indices = torch.unique( + diagonal_indices, return_inverse=True) + + diagonal_sums_list = torch.zeros( + unique_diagonal_indices.size(0), + dtype=similarity.dtype, + device=similarity.device) + diagonal_sums_list.scatter_add_(0, inverse_indices.view(-1), + similarity.view(-1)) + diagonal_sums_list[:min(m, n) // 4] = 0 + diagonal_sums_list[-min(m, n) // 4:] = 0 + index = diagonal_sums_list.argmax().item() + + similarity_smooth = torch.nn.functional.max_pool2d( + similarity[None], (1, 11), stride=(1, 1), padding=(0, 5))[0] + similarity_vec = similarity_smooth.diagonal(offset=index - m + + 1).cpu().numpy() + + stu_start = max(0, m - 1 - index) + tch_start = max(0, index - m + 1) + + return dict( + stu_start=stu_start, + tch_start=tch_start, + length=len(similarity_vec), + similarity=similarity_vec) diff --git a/projects/just_dance/configs/_base_ b/projects/just_dance/configs/_base_ new file mode 120000 index 0000000000..3bd06d44a7 --- /dev/null +++ b/projects/just_dance/configs/_base_ @@ -0,0 +1 @@ +../../../configs/_base_ \ No newline at end of file diff --git a/projects/just_dance/configs/rtmdet-nano_one-person.py b/projects/just_dance/configs/rtmdet-nano_one-person.py new file mode 100644 index 0000000000..a838522918 --- /dev/null +++ b/projects/just_dance/configs/rtmdet-nano_one-person.py @@ -0,0 +1,3 @@ +_base_ = '../../rtmpose/rtmdet/person/rtmdet_nano_320-8xb32_coco-person.py' + +model = dict(test_cfg=dict(nms_pre=1, score_thr=0.0, max_per_img=1)) diff --git a/projects/just_dance/just_dance_demo.ipynb b/projects/just_dance/just_dance_demo.ipynb new file mode 100644 index 0000000000..45a16e4b8c --- /dev/null +++ b/projects/just_dance/just_dance_demo.ipynb @@ -0,0 +1,712 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "6d999c38-2087-4250-b6a4-a30cf8b44ec0", + "metadata": { + 
"ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T13:11:38.997916Z", + "iopub.status.busy": "2023-07-05T13:11:38.997587Z", + "iopub.status.idle": "2023-07-05T13:11:39.001928Z", + "shell.execute_reply": "2023-07-05T13:11:39.001429Z", + "shell.execute_reply.started": "2023-07-05T13:11:38.997898Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import os.path as osp\n", + "import torch\n", + "import numpy as np\n", + "import mmcv\n", + "import cv2\n", + "from mmengine.utils import track_iter_progress" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfa9bf9b-dc2c-4803-a034-8ae8778113e0", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T12:42:15.884465Z", + "iopub.status.busy": "2023-07-05T12:42:15.884167Z", + "iopub.status.idle": "2023-07-05T12:42:19.774569Z", + "shell.execute_reply": "2023-07-05T12:42:19.774020Z", + "shell.execute_reply.started": "2023-07-05T12:42:15.884448Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# download example videos\n", + "from mmengine.utils import mkdir_or_exist\n", + "mkdir_or_exist('resources')\n", + "! wget -O resources/student_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tom.mp4 \n", + "! wget -O resources/teacher_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/idol_producer.mp4 \n", + "# ! wget -O resources/student_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tsinghua_30fps.mp4 \n", + "\n", + "student_video = 'resources/student_video.mp4'\n", + "teacher_video = 'resources/teacher_video.mp4'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "652b6b91-e1c0-461b-90e5-653bc35ec380", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T12:42:20.693931Z", + "iopub.status.busy": "2023-07-05T12:42:20.693353Z", + "iopub.status.idle": "2023-07-05T12:43:14.533985Z", + "shell.execute_reply": "2023-07-05T12:43:14.533431Z", + "shell.execute_reply.started": "2023-07-05T12:42:20.693910Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# convert the fps of videos to 30\n", + "from mmcv import VideoReader\n", + "\n", + "if VideoReader(student_video) != 30:\n", + " # ffmpeg is required to convert the video fps\n", + " # which can be installed via `sudo apt install ffmpeg` on ubuntu\n", + " student_video_30fps = student_video.replace(\n", + " f\".{student_video.rsplit('.', 1)[1]}\",\n", + " f\"_30fps.{student_video.rsplit('.', 1)[1]}\"\n", + " )\n", + " !ffmpeg -i {student_video} -vf \"minterpolate='fps=30'\" {student_video_30fps}\n", + " student_video = student_video_30fps\n", + " \n", + "if VideoReader(teacher_video) != 30:\n", + " teacher_video_30fps = teacher_video.replace(\n", + " f\".{teacher_video.rsplit('.', 1)[1]}\",\n", + " f\"_30fps.{teacher_video.rsplit('.', 1)[1]}\"\n", + " )\n", + " !ffmpeg -i {teacher_video} -vf \"minterpolate='fps=30'\" {teacher_video_30fps}\n", + " teacher_video = teacher_video_30fps " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a4e141d-ee4a-4e06-a380-230418c9b936", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T12:45:01.672054Z", + "iopub.status.busy": "2023-07-05T12:45:01.671727Z", + "iopub.status.idle": "2023-07-05T12:45:02.417026Z", + 
"shell.execute_reply": "2023-07-05T12:45:02.416567Z", + "shell.execute_reply.started": "2023-07-05T12:45:01.672035Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# init pose estimator\n", + "from mmpose.apis.inferencers import Pose2DInferencer\n", + "pose_estimator = Pose2DInferencer(\n", + " 'rtmpose-t_8xb256-420e_aic-coco-256x192',\n", + " det_model='configs/rtmdet-nano_one-person.py',\n", + " det_weights='https://download.openmmlab.com/mmpose/v1/projects/' \n", + " 'rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth'\n", + ")\n", + "pose_estimator.model.test_cfg['flip_test'] = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "879ba5c0-4d2d-4cca-92d7-d4f94e04a821", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T12:45:05.192437Z", + "iopub.status.busy": "2023-07-05T12:45:05.191982Z", + "iopub.status.idle": "2023-07-05T12:45:05.197379Z", + "shell.execute_reply": "2023-07-05T12:45:05.196780Z", + "shell.execute_reply.started": "2023-07-05T12:45:05.192417Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "@torch.no_grad()\n", + "def get_keypoints_from_frame(image, pose_estimator):\n", + " \"\"\"Extract keypoints from a single video frame.\"\"\"\n", + "\n", + " det_results = pose_estimator.detector(\n", + " image, return_datasample=True)['predictions']\n", + " pred_instance = det_results[0].pred_instances.numpy()\n", + "\n", + " if len(pred_instance) == 0 or pred_instance.scores[0] < 0.2:\n", + " return np.zeros((1, 17, 3), dtype=np.float32)\n", + "\n", + " data_info = dict(\n", + " img=image,\n", + " bbox=pred_instance.bboxes[:1],\n", + " bbox_score=pred_instance.scores[:1])\n", + "\n", + " data_info.update(pose_estimator.model.dataset_meta)\n", + " data = pose_estimator.collate_fn(\n", + " [pose_estimator.pipeline(data_info)])\n", + "\n", + " # custom forward\n", + " data = pose_estimator.model.data_preprocessor(data, False)\n", + " feats = pose_estimator.model.extract_feat(data['inputs'])\n", + " pred_instances = pose_estimator.model.head.predict(\n", + " feats,\n", + " data['data_samples'],\n", + " test_cfg=pose_estimator.model.test_cfg)[0]\n", + " keypoints = np.concatenate(\n", + " (pred_instances.keypoints, pred_instances.keypoint_scores[...,\n", + " None]),\n", + " axis=-1)\n", + "\n", + " return keypoints " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31e5bd4c-4c2b-4fe0-b64c-1afed67b7688", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T12:47:55.564788Z", + "iopub.status.busy": "2023-07-05T12:47:55.564450Z", + "iopub.status.idle": "2023-07-05T12:49:37.222662Z", + "shell.execute_reply": "2023-07-05T12:49:37.222028Z", + "shell.execute_reply.started": "2023-07-05T12:47:55.564770Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# pose estimation in two videos\n", + "student_poses, teacher_poses = [], []\n", + "for frame in VideoReader(student_video):\n", + " student_poses.append(get_keypoints_from_frame(frame, pose_estimator))\n", + "for frame in VideoReader(teacher_video):\n", + " teacher_poses.append(get_keypoints_from_frame(frame, pose_estimator))\n", + " \n", + "student_poses = np.concatenate(student_poses)\n", + "teacher_poses = np.concatenate(teacher_poses)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38a8d7a5-17ed-4ce2-bb8b-d1637cb49578", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + 
"iopub.execute_input": "2023-07-05T12:55:09.342432Z", + "iopub.status.busy": "2023-07-05T12:55:09.342185Z", + "iopub.status.idle": "2023-07-05T12:55:09.350522Z", + "shell.execute_reply": "2023-07-05T12:55:09.350099Z", + "shell.execute_reply.started": "2023-07-05T12:55:09.342416Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "valid_indices = np.array([0] + list(range(5, 17)))\n", + "\n", + "@torch.no_grad()\n", + "def _calculate_similarity(tch_kpts: np.ndarray, stu_kpts: np.ndarray):\n", + "\n", + " stu_kpts = torch.from_numpy(stu_kpts[:, None, valid_indices])\n", + " tch_kpts = torch.from_numpy(tch_kpts[None, :, valid_indices])\n", + " stu_kpts = stu_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1],\n", + " stu_kpts.shape[2], 3)\n", + " tch_kpts = tch_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1],\n", + " stu_kpts.shape[2], 3)\n", + "\n", + " matrix = torch.stack((stu_kpts, tch_kpts), dim=4)\n", + " if torch.cuda.is_available():\n", + " matrix = matrix.cuda()\n", + " # only consider visible keypoints\n", + " mask = torch.logical_and(matrix[:, :, :, 2, 0] > 0.3,\n", + " matrix[:, :, :, 2, 1] > 0.3)\n", + " matrix[~mask] = 0.0\n", + "\n", + " matrix_ = matrix.clone()\n", + " matrix_[matrix == 0] = 256\n", + " x_min = matrix_.narrow(3, 0, 1).min(dim=2).values\n", + " y_min = matrix_.narrow(3, 1, 1).min(dim=2).values\n", + " matrix_ = matrix.clone()\n", + " x_max = matrix_.narrow(3, 0, 1).max(dim=2).values\n", + " y_max = matrix_.narrow(3, 1, 1).max(dim=2).values\n", + "\n", + " matrix_ = matrix.clone()\n", + " matrix_[:, :, :, 0] = (matrix_[:, :, :, 0] - x_min) / (\n", + " x_max - x_min + 1e-4)\n", + " matrix_[:, :, :, 1] = (matrix_[:, :, :, 1] - y_min) / (\n", + " y_max - y_min + 1e-4)\n", + " matrix_[:, :, :, 2] = (matrix_[:, :, :, 2] > 0.3).float()\n", + " xy_dist = matrix_[..., :2, 0] - matrix_[..., :2, 1]\n", + " score = matrix_[..., 2, 0] * matrix_[..., 2, 1]\n", + "\n", + " similarity = (torch.exp(-50 * xy_dist.pow(2).sum(dim=-1)) *\n", + " score).sum(dim=-1) / (\n", + " score.sum(dim=-1) + 1e-6)\n", + " num_visible_kpts = score.sum(dim=-1)\n", + " similarity = similarity * torch.log(\n", + " (1 + (num_visible_kpts - 1) * 10).clamp(min=1)) / np.log(161)\n", + "\n", + " similarity[similarity.isnan()] = 0\n", + "\n", + " return similarity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "658bcf89-df06-4c73-9323-8973a49c14c3", + "metadata": { + "execution": { + "iopub.execute_input": "2023-07-05T12:55:31.978675Z", + "iopub.status.busy": "2023-07-05T12:55:31.978219Z", + "iopub.status.idle": "2023-07-05T12:55:32.149624Z", + "shell.execute_reply": "2023-07-05T12:55:32.148568Z", + "shell.execute_reply.started": "2023-07-05T12:55:31.978657Z" + } + }, + "outputs": [], + "source": [ + "# compute similarity without flip\n", + "similarity1 = _calculate_similarity(teacher_poses, student_poses)\n", + "\n", + "# compute similarity with flip\n", + "flip_indices = np.array(\n", + " [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15])\n", + "student_poses_flip = student_poses[:, flip_indices]\n", + "student_poses_flip[..., 0] = 191.5 - student_poses_flip[..., 0]\n", + "similarity2 = _calculate_similarity(teacher_poses, student_poses_flip)\n", + "\n", + "# select the larger similarity\n", + "similarity = torch.stack((similarity1, similarity2)).max(dim=0).values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f981410d-4585-47c1-98c0-6946f948487d", + "metadata": { + "ExecutionIndicator": { + "show": false + }, + "execution": { + 
"iopub.execute_input": "2023-07-05T12:55:57.321845Z", + "iopub.status.busy": "2023-07-05T12:55:57.321530Z", + "iopub.status.idle": "2023-07-05T12:55:57.582879Z", + "shell.execute_reply": "2023-07-05T12:55:57.582425Z", + "shell.execute_reply.started": "2023-07-05T12:55:57.321826Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# visualize the similarity\n", + "plt.imshow(similarity.cpu().numpy())\n", + "\n", + "# there is an apparent diagonal in the figure\n", + "# we can select matched video snippets with this diagonal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13c189e5-fc53-46a2-9057-f0f2ffc1f46d", + "metadata": { + "execution": { + "iopub.execute_input": "2023-07-05T12:58:16.913855Z", + "iopub.status.busy": "2023-07-05T12:58:16.913529Z", + "iopub.status.idle": "2023-07-05T12:58:16.919972Z", + "shell.execute_reply": "2023-07-05T12:58:16.919005Z", + "shell.execute_reply.started": "2023-07-05T12:58:16.913837Z" + } + }, + "outputs": [], + "source": [ + "@torch.no_grad()\n", + "def select_piece_from_similarity(similarity):\n", + " m, n = similarity.size()\n", + " row_indices = torch.arange(m).view(-1, 1).expand(m, n).to(similarity)\n", + " col_indices = torch.arange(n).view(1, -1).expand(m, n).to(similarity)\n", + " diagonal_indices = similarity.size(0) - 1 - row_indices + col_indices\n", + " unique_diagonal_indices, inverse_indices = torch.unique(\n", + " diagonal_indices, return_inverse=True)\n", + "\n", + " diagonal_sums_list = torch.zeros(\n", + " unique_diagonal_indices.size(0),\n", + " dtype=similarity.dtype,\n", + " device=similarity.device)\n", + " diagonal_sums_list.scatter_add_(0, inverse_indices.view(-1),\n", + " similarity.view(-1))\n", + " diagonal_sums_list[:min(m, n) // 4] = 0\n", + " diagonal_sums_list[-min(m, n) // 4:] = 0\n", + " index = diagonal_sums_list.argmax().item()\n", + "\n", + " similarity_smooth = torch.nn.functional.max_pool2d(\n", + " similarity[None], (1, 11), stride=(1, 1), padding=(0, 5))[0]\n", + " similarity_vec = similarity_smooth.diagonal(offset=index - m +\n", + " 1).cpu().numpy()\n", + "\n", + " stu_start = max(0, m - 1 - index)\n", + " tch_start = max(0, index - m + 1)\n", + "\n", + " return dict(\n", + " stu_start=stu_start,\n", + " tch_start=tch_start,\n", + " length=len(similarity_vec),\n", + " similarity=similarity_vec)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c0e19df-949d-471d-804d-409b3b9ddf7d", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T12:58:44.860190Z", + "iopub.status.busy": "2023-07-05T12:58:44.859878Z", + "iopub.status.idle": "2023-07-05T12:58:44.888465Z", + "shell.execute_reply": "2023-07-05T12:58:44.887917Z", + "shell.execute_reply.started": "2023-07-05T12:58:44.860173Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "matched_piece_info = select_piece_from_similarity(similarity)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51b0a2bd-253c-4a8f-a82a-263e18a4703e", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T13:01:19.061408Z", + "iopub.status.busy": "2023-07-05T13:01:19.060857Z", + "iopub.status.idle": "2023-07-05T13:01:19.293742Z", + "shell.execute_reply": "2023-07-05T13:01:19.293298Z", + "shell.execute_reply.started": "2023-07-05T13:01:19.061378Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "plt.imshow(similarity.cpu().numpy())\n", + "plt.plot((matched_piece_info['tch_start'], 
\n", + " matched_piece_info['tch_start']+matched_piece_info['length']-1),\n", + " (matched_piece_info['stu_start'],\n", + " matched_piece_info['stu_start']+matched_piece_info['length']-1), 'r')" + ] + }, + { + "cell_type": "markdown", + "id": "ffcde4e7-ff50-483a-b515-604c1d8f121a", + "metadata": {}, + "source": [ + "# Generate Output Video" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72171a0c-ab33-45bb-b84c-b15f0816ed3a", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T13:11:50.063595Z", + "iopub.status.busy": "2023-07-05T13:11:50.063259Z", + "iopub.status.idle": "2023-07-05T13:11:50.070929Z", + "shell.execute_reply": "2023-07-05T13:11:50.070411Z", + "shell.execute_reply.started": "2023-07-05T13:11:50.063574Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from typing import Tuple\n", + "\n", + "def resize_image_to_fixed_height(image: np.ndarray,\n", + " fixed_height: int) -> np.ndarray:\n", + " \"\"\"Resizes an input image to a specified fixed height while maintaining its\n", + " aspect ratio.\n", + "\n", + " Args:\n", + " image (np.ndarray): Input image as a numpy array [H, W, C]\n", + " fixed_height (int): Desired fixed height of the output image.\n", + "\n", + " Returns:\n", + " Resized image as a numpy array (fixed_height, new_width, channels).\n", + " \"\"\"\n", + " original_height, original_width = image.shape[:2]\n", + "\n", + " scale_ratio = fixed_height / original_height\n", + " new_width = int(original_width * scale_ratio)\n", + " resized_image = cv2.resize(image, (new_width, fixed_height))\n", + "\n", + " return resized_image\n", + "\n", + "def blend_images(img1: np.ndarray,\n", + " img2: np.ndarray,\n", + " blend_ratios: Tuple[float, float] = (1, 1)) -> np.ndarray:\n", + " \"\"\"Blends two input images with specified blend ratios.\n", + "\n", + " Args:\n", + " img1 (np.ndarray): First input image as a numpy array [H, W, C].\n", + " img2 (np.ndarray): Second input image as a numpy array [H, W, C]\n", + " blend_ratios (tuple): A tuple of two floats representing the blend\n", + " ratios for the two input images.\n", + "\n", + " Returns:\n", + " Blended image as a numpy array [H, W, C]\n", + " \"\"\"\n", + "\n", + " def normalize_image(image: np.ndarray) -> np.ndarray:\n", + " if image.dtype == np.uint8:\n", + " return image.astype(np.float32) / 255.0\n", + " return image\n", + "\n", + " img1 = normalize_image(img1)\n", + " img2 = normalize_image(img2)\n", + "\n", + " blended_image = img1 * blend_ratios[0] + img2 * blend_ratios[1]\n", + " blended_image = blended_image.clip(min=0, max=1)\n", + " blended_image = (blended_image * 255).astype(np.uint8)\n", + "\n", + " return blended_image\n", + "\n", + "def get_smoothed_kpt(kpts, index, sigma=5):\n", + " \"\"\"Smooths keypoints using a Gaussian filter.\"\"\"\n", + " assert kpts.shape[1] == 17\n", + " assert kpts.shape[2] == 3\n", + " assert sigma % 2 == 1\n", + "\n", + " num_kpts = len(kpts)\n", + "\n", + " start_idx = max(0, index - sigma // 2)\n", + " end_idx = min(num_kpts, index + sigma // 2 + 1)\n", + "\n", + " # Extract a piece of the keypoints array to apply the filter\n", + " piece = kpts[start_idx:end_idx].copy()\n", + " original_kpt = kpts[index]\n", + "\n", + " # Split the piece into coordinates and scores\n", + " coords, scores = piece[..., :2], piece[..., 2]\n", + "\n", + " # Calculate the Gaussian ratio for each keypoint\n", + " gaussian_ratio = np.arange(len(scores)) + start_idx - index\n", + " gaussian_ratio = 
np.exp(-gaussian_ratio**2 / 2)\n", + "\n", + " # Update scores using the Gaussian ratio\n", + " scores *= gaussian_ratio[:, None]\n", + "\n", + " # Compute the smoothed coordinates\n", + " smoothed_coords = (coords * scores[..., None]).sum(axis=0) / (\n", + " scores[..., None].sum(axis=0) + 1e-4)\n", + "\n", + " original_kpt[..., :2] = smoothed_coords\n", + "\n", + " return original_kpt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "609b5adc-e176-4bf9-b9a4-506f72440017", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T13:12:46.198835Z", + "iopub.status.busy": "2023-07-05T13:12:46.198268Z", + "iopub.status.idle": "2023-07-05T13:12:46.202273Z", + "shell.execute_reply": "2023-07-05T13:12:46.200881Z", + "shell.execute_reply.started": "2023-07-05T13:12:46.198815Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "score, last_vis_score = 0, 0\n", + "video_writer = None\n", + "output_file = 'output.mp4'\n", + "stu_kpts = student_poses\n", + "tch_kpts = teacher_poses" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a264405a-5d50-49de-8637-2d1f67cb0a70", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-07-05T13:13:11.334760Z", + "iopub.status.busy": "2023-07-05T13:13:11.334433Z", + "iopub.status.idle": "2023-07-05T13:13:17.264181Z", + "shell.execute_reply": "2023-07-05T13:13:17.262931Z", + "shell.execute_reply.started": "2023-07-05T13:13:11.334742Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from mmengine.structures import InstanceData\n", + "\n", + "tch_video_reader = VideoReader(teacher_video)\n", + "stu_video_reader = VideoReader(student_video)\n", + "for _ in range(matched_piece_info['tch_start']):\n", + " _ = next(tch_video_reader)\n", + "for _ in range(matched_piece_info['stu_start']):\n", + " _ = next(stu_video_reader)\n", + " \n", + "for i in track_iter_progress(range(matched_piece_info['length'])):\n", + " tch_frame = mmcv.bgr2rgb(next(tch_video_reader))\n", + " stu_frame = mmcv.bgr2rgb(next(stu_video_reader))\n", + " tch_frame = resize_image_to_fixed_height(tch_frame, 300)\n", + " stu_frame = resize_image_to_fixed_height(stu_frame, 300)\n", + "\n", + " stu_kpt = get_smoothed_kpt(stu_kpts, matched_piece_info['stu_start'] + i,\n", + " 5)\n", + " tch_kpt = get_smoothed_kpt(tch_kpts, matched_piece_info['tch_start'] + i,\n", + " 5)\n", + "\n", + " # draw pose\n", + " stu_kpt[..., 1] += (300 - 256)\n", + " tch_kpt[..., 0] += (256 - 192)\n", + " tch_kpt[..., 1] += (300 - 256)\n", + " stu_inst = InstanceData(\n", + " keypoints=stu_kpt[None, :, :2],\n", + " keypoint_scores=stu_kpt[None, :, 2])\n", + " tch_inst = InstanceData(\n", + " keypoints=tch_kpt[None, :, :2],\n", + " keypoint_scores=tch_kpt[None, :, 2])\n", + " \n", + " stu_out_img = pose_estimator.visualizer._draw_instances_kpts(\n", + " np.zeros((300, 256, 3)), stu_inst)\n", + " tch_out_img = pose_estimator.visualizer._draw_instances_kpts(\n", + " np.zeros((300, 256, 3)), tch_inst)\n", + " out_img = blend_images(\n", + " stu_out_img, tch_out_img, blend_ratios=(1, 0.3))\n", + "\n", + " # draw score\n", + " score_frame = matched_piece_info['similarity'][i]\n", + " score += score_frame * 1000\n", + " if score - last_vis_score > 1500:\n", + " last_vis_score = score\n", + " pose_estimator.visualizer.set_image(out_img)\n", + " pose_estimator.visualizer.draw_texts(\n", + " 'score: ', (60, 30),\n", + " font_sizes=15,\n", + " colors=(255, 255, 
255),\n", + " vertical_alignments='bottom')\n", + " pose_estimator.visualizer.draw_texts(\n", + " f'{int(last_vis_score)}', (115, 30),\n", + " font_sizes=30 * max(0.4, score_frame),\n", + " colors=(255, 255, 255),\n", + " vertical_alignments='bottom')\n", + " out_img = pose_estimator.visualizer.get_image() \n", + " \n", + " # concatenate\n", + " concatenated_image = np.hstack((stu_frame, out_img, tch_frame))\n", + " if video_writer is None:\n", + " video_writer = cv2.VideoWriter(output_file,\n", + " cv2.VideoWriter_fourcc(*'mp4v'),\n", + " 30,\n", + " (concatenated_image.shape[1],\n", + " concatenated_image.shape[0]))\n", + " video_writer.write(mmcv.rgb2bgr(concatenated_image))\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "745fdd75-6ed4-4cae-9f21-c2cd486ee918", + "metadata": { + "execution": { + "iopub.execute_input": "2023-07-05T13:13:18.704492Z", + "iopub.status.busy": "2023-07-05T13:13:18.704179Z", + "iopub.status.idle": "2023-07-05T13:13:18.714843Z", + "shell.execute_reply": "2023-07-05T13:13:18.713866Z", + "shell.execute_reply.started": "2023-07-05T13:13:18.704472Z" + } + }, + "outputs": [], + "source": [ + "if video_writer is not None:\n", + " video_writer.release() " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cb0bc99-ca19-44f1-bc0a-38e14afa980f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/projects/just_dance/process_video.py b/projects/just_dance/process_video.py new file mode 100644 index 0000000000..7f1d48b922 --- /dev/null +++ b/projects/just_dance/process_video.py @@ -0,0 +1,229 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os +import tempfile + +import cv2 +import mmcv +import numpy as np +import torch +from mmengine.structures import InstanceData +from mmengine.utils import track_iter_progress + +from mmpose.apis import Pose2DInferencer +from mmpose.datasets.datasets.utils import parse_pose_metainfo +from mmpose.visualization import PoseLocalVisualizer + +try: + from .calculate_similarity import (calculate_similarity, + select_piece_from_similarity) + from .utils import (blend_images, convert_video_fps, get_smoothed_kpt, + resize_image_to_fixed_height) +except ImportError: + from calculate_similarity import (calculate_similarity, + select_piece_from_similarity) + from utils import (blend_images, convert_video_fps, get_smoothed_kpt, + resize_image_to_fixed_height) + +det_config = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'configs/rtmdet-nano_one-person.py') +det_weights = 'https://download.openmmlab.com/mmpose/v1/projects/' \ + 'rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth' + + +class VideoProcessor: + """A class to process videos for pose estimation and visualization.""" + + @property + def pose_estimator(self) -> Pose2DInferencer: + if not hasattr(self, '_pose_estimator'): + self._pose_estimator = Pose2DInferencer( + 'rtmpose-t_8xb256-420e_aic-coco-256x192', + det_model=det_config, + det_weights=det_weights) + self._pose_estimator.model.test_cfg['flip_test'] = False + return self._pose_estimator + + @property + def visualizer(self) -> PoseLocalVisualizer: + if hasattr(self, '_visualizer'): + return self._visualizer + elif hasattr(self, '_pose_estimator'): + return self._pose_estimator.visualizer + + # init visualizer + self._visualizer = PoseLocalVisualizer() + metainfo_file = os.path.join( + os.path.dirname(os.path.abspath(__file__)).rsplit(os.sep, 1)[0], + 'configs/_base_/datasets/coco.py') + metainfo = parse_pose_metainfo(dict(from_file=metainfo_file)) + self._visualizer.set_dataset_meta(metainfo) + return self._visualizer + + @torch.no_grad() + def get_keypoints_from_frame(self, image: np.ndarray) -> np.ndarray: + """Extract keypoints from a single video frame.""" + + det_results = self.pose_estimator.detector( + image, return_datasample=True)['predictions'] + pred_instance = det_results[0].pred_instances + + if len(pred_instance) == 0: + return np.zeros((1, 17, 3), dtype=np.float32) + + # only select the most significant person + data_info = dict( + img=image, + bbox=pred_instance.bboxes.cpu().numpy()[:1], + bbox_score=pred_instance.scores.cpu().numpy()[:1]) + + if data_info['bbox_score'] < 0.2: + return np.zeros((1, 17, 3), dtype=np.float32) + + data_info.update(self.pose_estimator.model.dataset_meta) + data = self.pose_estimator.collate_fn( + [self.pose_estimator.pipeline(data_info)]) + + # custom forward + data = self.pose_estimator.model.data_preprocessor(data, False) + feats = self.pose_estimator.model.extract_feat(data['inputs']) + pred_instances = self.pose_estimator.model.head.predict( + feats, + data['data_samples'], + test_cfg=self.pose_estimator.model.test_cfg)[0] + keypoints = np.concatenate( + (pred_instances.keypoints, pred_instances.keypoint_scores[..., + None]), + axis=-1) + + return keypoints + + @torch.no_grad() + def get_keypoints_from_video(self, video: str) -> np.ndarray: + """Extract keypoints from a video.""" + + video_fname = video.rsplit('.', 1)[0] + if os.path.exists(f'{video_fname}_kpts.pth'): + keypoints = torch.load(f'{video_fname}_kpts.pth') + return keypoints + + video_reader = mmcv.VideoReader(video) + + if video_reader.fps 
!= 30: + video_reader = mmcv.VideoReader(convert_video_fps(video)) + + assert video_reader.fps == 30, f'only support videos with 30 FPS, ' \ + f'but the video {video_fname} has {video_reader.fps} fps' + keypoints_list = [] + for i, frame in enumerate(video_reader): + keypoints = self.get_keypoints_from_frame(frame) + keypoints_list.append(keypoints) + keypoints = np.concatenate(keypoints_list) + torch.save(keypoints, f'{video_fname}_kpts.pth') + return keypoints + + @torch.no_grad() + def run(self, tch_video: str, stu_video: str): + # extract human poses + tch_kpts = self.get_keypoints_from_video(tch_video) + stu_kpts = self.get_keypoints_from_video(stu_video) + + # compute similarity + similarity = calculate_similarity(tch_kpts, stu_kpts) + + # select piece + piece_info = select_piece_from_similarity(similarity) + + # output + tch_name = os.path.basename(tch_video).rsplit('.', 1)[0] + stu_name = os.path.basename(stu_video).rsplit('.', 1)[0] + fname = f'{tch_name}-{stu_name}.mp4' + output_file = os.path.join(tempfile.mkdtemp(), fname) + return self.generate_output_video(tch_video, stu_video, output_file, + tch_kpts, stu_kpts, piece_info) + + def generate_output_video(self, tch_video: str, stu_video: str, + output_file: str, tch_kpts: np.ndarray, + stu_kpts: np.ndarray, piece_info: dict) -> str: + """Generate an output video with keypoints overlay.""" + + tch_video_reader = mmcv.VideoReader(tch_video) + stu_video_reader = mmcv.VideoReader(stu_video) + for _ in range(piece_info['tch_start']): + _ = next(tch_video_reader) + for _ in range(piece_info['stu_start']): + _ = next(stu_video_reader) + + score, last_vis_score = 0, 0 + video_writer = None + for i in track_iter_progress(range(piece_info['length'])): + tch_frame = mmcv.bgr2rgb(next(tch_video_reader)) + stu_frame = mmcv.bgr2rgb(next(stu_video_reader)) + tch_frame = resize_image_to_fixed_height(tch_frame, 300) + stu_frame = resize_image_to_fixed_height(stu_frame, 300) + + stu_kpt = get_smoothed_kpt(stu_kpts, piece_info['stu_start'] + i, + 5) + tch_kpt = get_smoothed_kpt(tch_kpts, piece_info['tch_start'] + i, + 5) + + # draw pose + stu_kpt[..., 1] += (300 - 256) + tch_kpt[..., 0] += (256 - 192) + tch_kpt[..., 1] += (300 - 256) + stu_inst = InstanceData( + keypoints=stu_kpt[None, :, :2], + keypoint_scores=stu_kpt[None, :, 2]) + tch_inst = InstanceData( + keypoints=tch_kpt[None, :, :2], + keypoint_scores=tch_kpt[None, :, 2]) + + stu_out_img = self.visualizer._draw_instances_kpts( + np.zeros((300, 256, 3)), stu_inst) + tch_out_img = self.visualizer._draw_instances_kpts( + np.zeros((300, 256, 3)), tch_inst) + out_img = blend_images( + stu_out_img, tch_out_img, blend_ratios=(1, 0.3)) + + # draw score + score_frame = piece_info['similarity'][i] + score += score_frame * 1000 + if score - last_vis_score > 1500: + last_vis_score = score + self.visualizer.set_image(out_img) + self.visualizer.draw_texts( + 'score: ', (60, 30), + font_sizes=15, + colors=(255, 255, 255), + vertical_alignments='bottom') + self.visualizer.draw_texts( + f'{int(last_vis_score)}', (115, 30), + font_sizes=30 * max(0.4, score_frame), + colors=(255, 255, 255), + vertical_alignments='bottom') + out_img = self.visualizer.get_image() + + # concatenate + concatenated_image = np.hstack((stu_frame, out_img, tch_frame)) + if video_writer is None: + video_writer = cv2.VideoWriter(output_file, + cv2.VideoWriter_fourcc(*'mp4v'), + 30, + (concatenated_image.shape[1], + concatenated_image.shape[0])) + video_writer.write(mmcv.rgb2bgr(concatenated_image)) + + if video_writer is not None: + 
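+            # Releasing the writer flushes buffered frames and finalizes the
+            # MP4 container so the output file is playable.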
video_writer.release() + return output_file + + +if __name__ == '__main__': + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('teacher_video', help='Path to the Teacher Video') + parser.add_argument('student_video', help='Path to the Student Video') + args = parser.parse_args() + + processor = VideoProcessor() + processor.run(args.teacher_video, args.student_video) diff --git a/projects/just_dance/utils.py b/projects/just_dance/utils.py new file mode 100644 index 0000000000..cd150bb1be --- /dev/null +++ b/projects/just_dance/utils.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import Tuple + +import cv2 +import numpy as np + + +def resize_image_to_fixed_height(image: np.ndarray, + fixed_height: int) -> np.ndarray: + """Resizes an input image to a specified fixed height while maintaining its + aspect ratio. + + Args: + image (np.ndarray): Input image as a numpy array [H, W, C] + fixed_height (int): Desired fixed height of the output image. + + Returns: + Resized image as a numpy array (fixed_height, new_width, channels). + """ + original_height, original_width = image.shape[:2] + + scale_ratio = fixed_height / original_height + new_width = int(original_width * scale_ratio) + resized_image = cv2.resize(image, (new_width, fixed_height)) + + return resized_image + + +def blend_images(img1: np.ndarray, + img2: np.ndarray, + blend_ratios: Tuple[float, float] = (1, 1)) -> np.ndarray: + """Blends two input images with specified blend ratios. + + Args: + img1 (np.ndarray): First input image as a numpy array [H, W, C]. + img2 (np.ndarray): Second input image as a numpy array [H, W, C] + blend_ratios (tuple): A tuple of two floats representing the blend + ratios for the two input images. 
+ + Returns: + Blended image as a numpy array [H, W, C] + """ + + def normalize_image(image: np.ndarray) -> np.ndarray: + if image.dtype == np.uint8: + return image.astype(np.float32) / 255.0 + return image + + img1 = normalize_image(img1) + img2 = normalize_image(img2) + + blended_image = img1 * blend_ratios[0] + img2 * blend_ratios[1] + blended_image = blended_image.clip(min=0, max=1) + blended_image = (blended_image * 255).astype(np.uint8) + + return blended_image + + +def convert_video_fps(video): + + input_video = video + video_name, post_fix = input_video.rsplit('.', 1) + output_video = f'{video_name}_30fps.{post_fix}' + if os.path.exists(output_video): + return output_video + + os.system( + f"ffmpeg -i {input_video} -vf \"minterpolate='fps=30'\" {output_video}" + ) + + return output_video + + +def get_smoothed_kpt(kpts, index, sigma=5): + """Smooths keypoints using a Gaussian filter.""" + assert kpts.shape[1] == 17 + assert kpts.shape[2] == 3 + assert sigma % 2 == 1 + + num_kpts = len(kpts) + + start_idx = max(0, index - sigma // 2) + end_idx = min(num_kpts, index + sigma // 2 + 1) + + # Extract a piece of the keypoints array to apply the filter + piece = kpts[start_idx:end_idx].copy() + original_kpt = kpts[index] + + # Split the piece into coordinates and scores + coords, scores = piece[..., :2], piece[..., 2] + + # Calculate the Gaussian ratio for each keypoint + gaussian_ratio = np.arange(len(scores)) + start_idx - index + gaussian_ratio = np.exp(-gaussian_ratio**2 / 2) + + # Update scores using the Gaussian ratio + scores *= gaussian_ratio[:, None] + + # Compute the smoothed coordinates + smoothed_coords = (coords * scores[..., None]).sum(axis=0) / ( + scores[..., None].sum(axis=0) + 1e-4) + + original_kpt[..., :2] = smoothed_coords + + return original_kpt
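+
+
+# Illustrative usage of the helpers above (a sketch only; the file names are
+# hypothetical and not part of this project):
+#
+#   kpts = torch.load('student_video_kpts.pth')   # (num_frames, 17, 3) array
+#                                                  # cached by process_video.py
+#   smoothed_kpt = get_smoothed_kpt(kpts, index=10, sigma=5)
+#   frame = resize_image_to_fixed_height(cv2.imread('frame.jpg'), 300)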