-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
1,304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import os | ||
import sys | ||
from functools import partial | ||
from typing import Optional | ||
|
||
project_path = os.path.join(os.path.dirname(os.path.abspath(__file__))) | ||
mmpose_path = project_path.split('/projects', 1)[0] | ||
|
||
os.system('python -m pip install Openmim') | ||
os.system('python -m mim install "mmcv>=2.0.0"') | ||
os.system('python -m mim install mmengine') | ||
os.system('python -m mim install "mmdet>=3.0.0"') | ||
os.system(f'python -m mim install -e {mmpose_path}') | ||
|
||
os.environ['PATH'] = f"{os.environ['PATH']}:{project_path}" | ||
os.environ[ | ||
'PYTHONPATH'] = f"{os.environ.get('PYTHONPATH', '.')}:{project_path}" | ||
sys.path.append(project_path) | ||
|
||
import gradio as gr # noqa | ||
from mmengine.utils import mkdir_or_exist # noqa | ||
from process_video import VideoProcessor # noqa | ||
|
||
|
||
def process_video( | ||
teacher_video: Optional[str] = None, | ||
student_video: Optional[str] = None, | ||
): | ||
print(teacher_video) | ||
print(student_video) | ||
|
||
video_processor = VideoProcessor() | ||
if student_video is None and teacher_video is not None: | ||
# Pre-process the teacher video when users record the student video | ||
# using a webcam. This allows users to view the teacher video and | ||
# follow the dance moves while recording the student video. | ||
_ = video_processor.get_keypoints_from_video(teacher_video) | ||
return teacher_video | ||
elif teacher_video is None and student_video is not None: | ||
_ = video_processor.get_keypoints_from_video(student_video) | ||
return student_video | ||
elif teacher_video is None and student_video is None: | ||
return None | ||
|
||
return video_processor.run(teacher_video, student_video) | ||
|
||
|
||
# download video resources | ||
mkdir_or_exist(os.path.join(project_path, 'resources')) | ||
os.system( | ||
f'wget -O {project_path}/resources/tom.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tom.mp4' # noqa | ||
) | ||
os.system( | ||
f'wget -O {project_path}/resources/idol_producer.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/idol_producer.mp4' # noqa | ||
) | ||
os.system( | ||
f'wget -O {project_path}/resources/tsinghua_30fps.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tsinghua_30fps.mp4' # noqa | ||
) | ||
|
||
with gr.Blocks() as demo: | ||
with gr.Tab('Upload-Video'): | ||
with gr.Row(): | ||
with gr.Column(): | ||
gr.Markdown('Student Video') | ||
student_video = gr.Video(type='mp4') | ||
gr.Examples([ | ||
os.path.join(project_path, 'resources/tom.mp4'), | ||
os.path.join(project_path, 'resources/tsinghua_30fps.mp4') | ||
], student_video) | ||
with gr.Column(): | ||
gr.Markdown('Teacher Video') | ||
teacher_video = gr.Video(type='mp4') | ||
gr.Examples([ | ||
os.path.join(project_path, 'resources/idol_producer.mp4') | ||
], teacher_video) | ||
|
||
button = gr.Button('Grading', variant='primary') | ||
gr.Markdown('## Display') | ||
out_video = gr.Video() | ||
|
||
button.click( | ||
partial(process_video), [teacher_video, student_video], out_video) | ||
|
||
with gr.Tab('Webcam-Video'): | ||
with gr.Row(): | ||
with gr.Column(): | ||
gr.Markdown('Student Video') | ||
student_video = gr.Video(source='webcam', type='mp4') | ||
with gr.Column(): | ||
gr.Markdown('Teacher Video') | ||
teacher_video = gr.Video(type='mp4') | ||
gr.Examples([ | ||
os.path.join(project_path, 'resources/idol_producer.mp4') | ||
], teacher_video) | ||
button_upload = gr.Button('Upload', variant='primary') | ||
|
||
button = gr.Button('Grading', variant='primary') | ||
gr.Markdown('## Display') | ||
out_video = gr.Video() | ||
|
||
button_upload.click( | ||
partial(process_video), [teacher_video, student_video], out_video) | ||
button.click( | ||
partial(process_video), [teacher_video, student_video], out_video) | ||
|
||
gr.close_all() | ||
demo.queue() | ||
demo.launch() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import numpy as np | ||
import torch | ||
|
||
flip_indices = np.array( | ||
[0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]) | ||
valid_indices = np.array([0] + list(range(5, 17))) | ||
|
||
|
||
@torch.no_grad() | ||
def _calculate_similarity(tch_kpts: np.ndarray, stu_kpts: np.ndarray): | ||
|
||
stu_kpts = torch.from_numpy(stu_kpts[:, None, valid_indices]) | ||
tch_kpts = torch.from_numpy(tch_kpts[None, :, valid_indices]) | ||
stu_kpts = stu_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1], | ||
stu_kpts.shape[2], 3) | ||
tch_kpts = tch_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1], | ||
stu_kpts.shape[2], 3) | ||
|
||
matrix = torch.stack((stu_kpts, tch_kpts), dim=4) | ||
if torch.cuda.is_available(): | ||
matrix = matrix.cuda() | ||
mask = torch.logical_and(matrix[:, :, :, 2, 0] > 0.3, | ||
matrix[:, :, :, 2, 1] > 0.3) | ||
matrix[~mask] = 0.0 | ||
|
||
matrix_ = matrix.clone() | ||
matrix_[matrix == 0] = 256 | ||
x_min = matrix_.narrow(3, 0, 1).min(dim=2).values | ||
y_min = matrix_.narrow(3, 1, 1).min(dim=2).values | ||
matrix_ = matrix.clone() | ||
# matrix_[matrix == 0] = 0 | ||
x_max = matrix_.narrow(3, 0, 1).max(dim=2).values | ||
y_max = matrix_.narrow(3, 1, 1).max(dim=2).values | ||
|
||
matrix_ = matrix.clone() | ||
matrix_[:, :, :, 0] = (matrix_[:, :, :, 0] - x_min) / ( | ||
x_max - x_min + 1e-4) | ||
matrix_[:, :, :, 1] = (matrix_[:, :, :, 1] - y_min) / ( | ||
y_max - y_min + 1e-4) | ||
matrix_[:, :, :, 2] = (matrix_[:, :, :, 2] > 0.3).float() | ||
xy_dist = matrix_[..., :2, 0] - matrix_[..., :2, 1] | ||
score = matrix_[..., 2, 0] * matrix_[..., 2, 1] | ||
|
||
similarity = (torch.exp(-50 * xy_dist.pow(2).sum(dim=-1)) * | ||
score).sum(dim=-1) / ( | ||
score.sum(dim=-1) + 1e-6) | ||
num_visible_kpts = score.sum(dim=-1) | ||
similarity = similarity * torch.log( | ||
(1 + (num_visible_kpts - 1) * 10).clamp(min=1)) / np.log(161) | ||
|
||
similarity[similarity.isnan()] = 0 | ||
|
||
return similarity | ||
|
||
|
||
@torch.no_grad() | ||
def calculate_similarity(tch_kpts: np.ndarray, stu_kpts: np.ndarray): | ||
assert tch_kpts.shape[1] == 17 | ||
assert tch_kpts.shape[2] == 3 | ||
assert stu_kpts.shape[1] == 17 | ||
assert stu_kpts.shape[2] == 3 | ||
|
||
similarity1 = _calculate_similarity(tch_kpts, stu_kpts) | ||
|
||
stu_kpts_flip = stu_kpts[:, flip_indices] | ||
stu_kpts_flip[..., 0] = 191.5 - stu_kpts_flip[..., 0] | ||
similarity2 = _calculate_similarity(tch_kpts, stu_kpts_flip) | ||
|
||
similarity = torch.stack((similarity1, similarity2)).max(dim=0).values | ||
|
||
return similarity | ||
|
||
|
||
@torch.no_grad() | ||
def select_piece_from_similarity(similarity): | ||
m, n = similarity.size() | ||
row_indices = torch.arange(m).view(-1, 1).expand(m, n).to(similarity) | ||
col_indices = torch.arange(n).view(1, -1).expand(m, n).to(similarity) | ||
diagonal_indices = similarity.size(0) - 1 - row_indices + col_indices | ||
unique_diagonal_indices, inverse_indices = torch.unique( | ||
diagonal_indices, return_inverse=True) | ||
|
||
diagonal_sums_list = torch.zeros( | ||
unique_diagonal_indices.size(0), | ||
dtype=similarity.dtype, | ||
device=similarity.device) | ||
diagonal_sums_list.scatter_add_(0, inverse_indices.view(-1), | ||
similarity.view(-1)) | ||
diagonal_sums_list[:min(m, n) // 4] = 0 | ||
diagonal_sums_list[-min(m, n) // 4:] = 0 | ||
index = diagonal_sums_list.argmax().item() | ||
|
||
similarity_smooth = torch.nn.functional.max_pool2d( | ||
similarity[None], (1, 11), stride=(1, 1), padding=(0, 5))[0] | ||
similarity_vec = similarity_smooth.diagonal(offset=index - m + | ||
1).cpu().numpy() | ||
|
||
stu_start = max(0, m - 1 - index) | ||
tch_start = max(0, index - m + 1) | ||
|
||
return dict( | ||
stu_start=stu_start, | ||
tch_start=tch_start, | ||
length=len(similarity_vec), | ||
similarity=similarity_vec) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../configs/_base_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
_base_ = '../../rtmpose/rtmdet/person/rtmdet_nano_320-8xb32_coco-person.py' | ||
|
||
model = dict(test_cfg=dict(nms_pre=1, score_thr=0.0, max_per_img=1)) |
Oops, something went wrong.