Add Robust Video Matting #223

Open · wants to merge 5 commits into `master`
Binary file added images/rvm_feature.gif
Binary file added images/rvm_logo.gif
99 changes: 99 additions & 0 deletions peterl1n_robustvideomatting.md
@@ -0,0 +1,99 @@
---
layout: hub_detail
background-class: hub-background
body-class: hub
category: researchers
title: Robust Video Matting
summary: RVM is a matting network for background replacement tasks. It is trained on human subjects and supports high resolutions such as 4K/HD.
image: rvm_logo.gif
author: Peter Lin
tags: [vision, scriptable]
github-link: https://github.com/PeterL1n/RobustVideoMatting
github-id: PeterL1n/RobustVideoMatting
featured_image_1: rvm_feature.gif
featured_image_2: no-image
accelerator: cuda-optional
---

### Model Description

Robust Video Matting (RVM) is a human matting network proposed in the paper ["Robust High-Resolution Video Matting with Temporal Guidance"](https://peterl1n.github.io/RobustVideoMatting/). It can process high resolutions such as 4K/HD in real time (depending on the GPU), and it is natively designed for video: a recurrent architecture exploits temporal information across frames.

### Demo

Watch the showreel video on [YouTube](https://youtu.be/Jvzltozpbpk) or [Bilibili](https://www.bilibili.com/video/BV1Z3411B7g7/) to see how it performs!


### Load the Model

Make sure you have `torch>=1.8.1` and `torchvision>=0.9.1` installed.
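For example, both can be installed with pip, using the version specifiers from the requirement above:

```sh
pip install "torch>=1.8.1" "torchvision>=0.9.1"
```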

```python
import torch

model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3").cuda().eval()
```

We recommend the `mobilenetv3` backbone for most use cases. We also provide a `resnet50` backbone with a slight improvement in performance.
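For example, a minimal sketch of loading the `resnet50` variant instead, with a CPU fallback (the device-availability check is a suggested pattern, not part of the repo's instructions):

```python
# Load the resnet50 variant; fall back to CPU if no GPU is present.
# (The availability check is a suggested pattern, not from the repo docs.)
model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50")
model = model.cuda().eval() if torch.cuda.is_available() else model.eval()
```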

### Use Inference API

We provide a simple inference function, `convert_video`, that applies the model to your videos. First, install its dependencies:

```sh
pip install av pims tqdm
```

```python
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
```

```python
convert_video(
    model,                         # The loaded model; can be on any device (cpu or cuda).
    input_source='input.mp4',      # A video file or an image sequence directory.
    input_resize=None,             # [Optional] Resize the input to (width, height).
    downsample_ratio=None,         # [Optional] Advanced hyperparameter. See the inference doc.
    output_type='video',           # Choose "video" or "png_sequence".
    output_composition='com.mp4',  # File path if video; directory path if png sequence.
    output_alpha="pha.mp4",        # [Optional] Output the raw alpha prediction.
    output_foreground="fgr.mp4",   # [Optional] Output the raw foreground prediction.
    output_video_mbps=4,           # Output video bitrate in Mbps. Not needed for png sequence.
    seq_chunk=4,                   # Process n frames at once for better parallelism.
    num_workers=0,                 # Only for image sequence input.
    progress=True                  # Print conversion progress.
)
```

> **Review comment** (Member, on `input_source='input.mp4'`): Would it be possible to just download the same demo video as in [this Colab notebook](https://colab.research.google.com/drive/10z-pNKRnVNsp0Lq9tH1J_XPZ7CBC_uHm?usp=sharing#scrollTo=-iJwFwqUI9Az)? This way, if running the model doesn't take too long, we'd be able to properly test the file instead of skipping it.

For more advanced usage, see the [inference documentation](https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/inference.md). Note that the inference function will likely not run in real time for HD/4K, because video encoding and decoding are not hardware-accelerated. You can implement your own inference loop, as shown below.

### Implement Custom Inference Loops

```python
bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
rec = [None] * 4                                       # Initial recurrent states.
downsample_ratio = 0.25                                # Adjust based on your video.

with torch.no_grad():
    for src in __YOUR_VIDEO_READER__:                  # RGB video frame, normalized to 0 ~ 1.
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Cycle the recurrent states.
        out = fgr * pha + bgr * (1 - pha)              # Composite onto the green background.
        __YOUR_VIDEO_WRITER__(out)
```

* Since RVM uses a recurrent architecture, the recurrent states `rec` are initialized to `None` and cycled through the model as each frame is processed.
* `src` can be `BCHW` or `BTCHW`, where `T` is a chunk of frames given to the model at once to improve parallelism.
* `downsample_ratio` is an important hyperparameter. See the [inference documentation](https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/inference.md) for more details.

One concrete way to fill in the `__YOUR_VIDEO_READER__` and `__YOUR_VIDEO_WRITER__` placeholders is sketched below.
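This is a minimal sketch using `torchvision`'s `read_video`/`write_video`, under a few assumptions: the file names `input.mp4`/`output.mp4` are hypothetical, the entire clip is read into memory (so it only suits short videos), a CUDA GPU is available, and `downsample_ratio=0.25` is kept from the loop above. It is an illustration, not the repo's recommended pipeline.

```python
import torch
from torchvision.io import read_video, write_video  # write_video needs PyAV ("pip install av").

model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3").cuda().eval()

# Hypothetical input file; read_video loads the whole clip into memory.
frames, _, info = read_video("input.mp4", pts_unit="sec")  # T x H x W x C, uint8.
src = frames.permute(0, 3, 1, 2).float().div(255)          # T x C x H x W, normalized to 0 ~ 1.

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()      # Green background.
rec = [None] * 4                                           # Initial recurrent states.
out = []
with torch.no_grad():
    for frame in src:                                      # One frame at a time.
        fgr, pha, *rec = model(frame.unsqueeze(0).cuda(), *rec, 0.25)
        out.append((fgr * pha + bgr * (1 - pha)).cpu())    # Composite onto green.

video = torch.cat(out).clamp(0, 1).mul(255).byte().permute(0, 2, 3, 1)  # Back to T x H x W x C.
write_video("output.mp4", video, fps=info["video_fps"])
```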

### References

> **Review comment** (Member): On top of the BibTeX, would you mind adding a link to the arXiv page of the paper as well?

```bibtex
@misc{rvm,
title={Robust High-Resolution Video Matting with Temporal Guidance},
author={Shanchuan Lin and Linjie Yang and Imran Saleemi and Soumyadip Sengupta},
year={2021},
eprint={2108.11515},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
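The paper's arXiv page, per the `eprint` field above, is [arXiv:2108.11515](https://arxiv.org/abs/2108.11515).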
2 changes: 2 additions & 0 deletions test_run_python_code.py
@@ -23,6 +23,8 @@ def test_run_file(file_path):
        pytest.skip("temporarily disabled")
    if 'pytorch_fairseq_roberta' in file_path:
        pytest.skip("temporarily disabled")
    if 'robustvideomatting' in file_path:
        pytest.skip("temporarily disabled")

    # We just run the python files in a separate sub-process. We really want a
    # subprocess here because otherwise we might run into package versions