forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add AnimeGANv2 model (PaddlePaddle#102)
* add animeganv2 network and dataset * animegan:refine code,add License Co-authored-by: qingqing01 <[email protected]>
- Loading branch information
Showing
24 changed files
with
1,367 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import paddle | ||
import os | ||
import sys | ||
sys.path.insert(0, os.getcwd()) | ||
from ppgan.apps import AnimeGANPredictor | ||
import argparse | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--input_image", type=str, help="path to source image") | ||
|
||
parser.add_argument("--output_path", | ||
type=str, | ||
default='output_dir', | ||
help="path to output image dir") | ||
|
||
parser.add_argument("--weight_path", | ||
type=str, | ||
default=None, | ||
help="path to model checkpoint path") | ||
|
||
parser.add_argument("--use_adjust_brightness", | ||
action="store_false", | ||
help="adjust brightness mode.") | ||
|
||
parser.add_argument("--cpu", | ||
dest="cpu", | ||
action="store_true", | ||
help="cpu mode.") | ||
|
||
args = parser.parse_args() | ||
|
||
if args.cpu: | ||
paddle.set_device('cpu') | ||
|
||
predictor = AnimeGANPredictor(args.output_path, args.weight_path, | ||
args.use_adjust_brightness) | ||
predictor.run(args.input_image) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
epochs: 30 | ||
output_dir: output_dir | ||
pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams | ||
g_adv_weight: 300. | ||
d_adv_weight: 300. | ||
con_weight: 1.5 | ||
sty_weight: 2.5 | ||
color_weight: 10. | ||
tv_weight: 1. | ||
|
||
model: | ||
name: AnimeGANV2Model | ||
generator: | ||
name: AnimeGenerator | ||
discriminator: | ||
name: AnimeDiscriminator | ||
gan_mode: lsgan | ||
|
||
dataset: | ||
train: | ||
name: AnimeGANV2Dataset | ||
num_workers: 4 | ||
batch_size: 4 | ||
dataroot: data/animedataset | ||
style: Hayao | ||
phase: train | ||
direction: AtoB | ||
transform_real: | ||
- name: Transpose | ||
- name: Normalize | ||
mean: [127.5, 127.5, 127.5] | ||
std: [127.5, 127.5, 127.5] | ||
transform_anime: | ||
- name: Add | ||
value: [-4.4346957, -8.665916, 13.100612] | ||
- name: Transpose | ||
- name: Normalize | ||
mean: [127.5, 127.5, 127.5] | ||
std: [127.5, 127.5, 127.5] | ||
transform_gray: | ||
- name: Grayscale | ||
num_output_channels: 3 | ||
- name: Transpose | ||
- name: Normalize | ||
mean: [127.5, 127.5, 127.5] | ||
std: [127.5, 127.5, 127.5] | ||
test: | ||
name: SingleDataset | ||
dataroot: data/animedataset/test/HR_photo | ||
max_dataset_size: inf | ||
direction: BtoA | ||
input_nc: 3 | ||
output_nc: 3 | ||
serial_batches: False | ||
pool_size: 50 | ||
transforms: | ||
- name: ResizeToScale | ||
size: [256, 256] | ||
scale: 32 | ||
interpolation: bilinear | ||
- name: Transpose | ||
- name: Normalize | ||
mean: [127.5, 127.5, 127.5] | ||
std: [127.5, 127.5, 127.5] | ||
|
||
optimizer: | ||
name: Adam | ||
beta1: 0.5 | ||
|
||
lr_scheduler: | ||
name: linear | ||
learning_rate: 0.00002 | ||
start_epoch: 100 | ||
decay_epochs: 100 | ||
|
||
log_config: | ||
interval: 100 | ||
visiual_interval: 100 | ||
|
||
snapshot_config: | ||
interval: 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
epochs: 2 | ||
output_dir: output_dir | ||
con_weight: 1 | ||
pretrain_ckpt: null | ||
|
||
model: | ||
name: AnimeGANV2PreTrainModel | ||
generator: | ||
name: AnimeGenerator | ||
discriminator: | ||
name: AnimeDiscriminator | ||
gan_mode: lsgan | ||
|
||
dataset: | ||
train: | ||
name: AnimeGANV2Dataset | ||
num_workers: 4 | ||
batch_size: 4 | ||
dataroot: data/animedataset | ||
style: Hayao | ||
phase: train | ||
direction: AtoB | ||
transform_real: | ||
- name: Transpose | ||
- name: Normalize | ||
mean: [127.5, 127.5, 127.5] | ||
std: [127.5, 127.5, 127.5] | ||
transform_anime: | ||
- name: Add | ||
value: [-4.4346957, -8.665916, 13.100612] | ||
- name: Transpose | ||
- name: Normalize | ||
mean: [127.5, 127.5, 127.5] | ||
std: [127.5, 127.5, 127.5] | ||
transform_gray: | ||
- name: Grayscale | ||
num_output_channels: 3 | ||
- name: Transpose | ||
- name: Normalize | ||
mean: [127.5, 127.5, 127.5] | ||
std: [127.5, 127.5, 127.5] | ||
test: | ||
name: SingleDataset | ||
dataroot: data/animedataset/test/test_photo | ||
max_dataset_size: inf | ||
direction: BtoA | ||
input_nc: 3 | ||
output_nc: 3 | ||
serial_batches: False | ||
pool_size: 50 | ||
transforms: | ||
- name: Resize | ||
size: [256, 256] | ||
interpolation: "bicubic" #cv2.INTER_CUBIC | ||
- name: Transpose | ||
- name: Normalize | ||
mean: [127.5, 127.5, 127.5] | ||
std: [127.5, 127.5, 127.5] | ||
|
||
optimizer: | ||
name: Adam | ||
beta1: 0.5 | ||
|
||
lr_scheduler: | ||
name: linear | ||
learning_rate: 0.0002 | ||
start_epoch: 100 | ||
decay_epochs: 100 | ||
|
||
log_config: | ||
interval: 100 | ||
visiual_interval: 100 | ||
|
||
snapshot_config: | ||
interval: 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# 1 AnimeGANv2 | ||
|
||
## 1.1 Introduction | ||
|
||
[AnimeGAN](https://github.com/TachibanaYoshino/AnimeGANv2) improved the [CVPR paper CartoonGAN](https://openaccess.thecvf.com/content_cvpr_2018/papers/Chen_CartoonGAN_Generative_Adversarial_CVPR_2018_paper.pdf), mainly to solve the over-stylized and color artifact area. For the details, you can refer to the [Zhihu article](https://zhuanlan.zhihu.com/p/76574388?from_voters_page=true) writes by the paper author.Based on the AnimeGAN, the AnimeGANv2 add the `total variation loss` in the generator loss. | ||
|
||
|
||
## 1.2 How to use | ||
|
||
### 1.2.1 Quick start | ||
|
||
After installing PaddleGAN, you can run python code as follows to generate the stylized image. Where the `PATH_OF_IMAGE` is your source image path. | ||
|
||
```python | ||
from ppgan.apps import AnimeGANPredictor | ||
predictor = AnimeGANPredictor() | ||
predictor.run(PATH_OF_IMAGE) | ||
``` | ||
|
||
Or run such a command to get the same result: | ||
|
||
```sh | ||
python applications/tools/animeganv2.py --input_image ${PATH_OF_IMAGE} | ||
``` | ||
|
||
### 1.2.1 Prepare dataset | ||
|
||
We download the dataset provided by the author from [here](https://github.com/TachibanaYoshino/AnimeGAN/releases/tag/dataset-1).Then unzip to the `data` directory. | ||
|
||
```sh | ||
wget https://github.com/TachibanaYoshino/AnimeGAN/releases/download/dataset-1/dataset.zip | ||
cd PaddleGAN | ||
unzip YOUR_DATASET_DIR/dataset.zip -d data/animedataset | ||
``` | ||
|
||
For example, the structure of `animedataset` is as following: | ||
|
||
```sh | ||
animedataset | ||
├── Hayao | ||
│ ├── smooth | ||
│ └── style | ||
├── Paprika | ||
│ ├── smooth | ||
│ └── style | ||
├── Shinkai | ||
│ ├── smooth | ||
│ └── style | ||
├── SummerWar | ||
│ ├── smooth | ||
│ └── style | ||
├── test | ||
│ ├── HR_photo | ||
│ ├── label_map | ||
│ ├── real | ||
│ ├── test_photo | ||
│ └── test_photo256 | ||
├── train_photo | ||
└── val | ||
``` | ||
|
||
### 1.2.2 Training | ||
|
||
An example is training to Hayao stylize. | ||
|
||
1. To ensure the generator can generate the original image, we need to warmup the model.: | ||
```sh | ||
python tools/main.py --config-file configs/animeganv2_pretrain.yaml | ||
``` | ||
|
||
2. After the warmup, we strat to training GAN.: | ||
**NOTE:** you must modify the `configs/animeganv2.yaml > pretrain_ckpt ` parameter first! ensure the GAN can reuse the warmup generator model. | ||
Set the `batch size=4` and the `learning rate=0.00002`. Train 30 epochs on a GTX2060S GPU to reproduce the result. For other hyperparameters, please refer to `configs/animeganv2.yaml`. | ||
```sh | ||
python tools/main.py --config-file configs/animeganv2.yaml | ||
``` | ||
|
||
3. Change target style | ||
Modify `style` parameter in the `configs/animeganv2.yaml`, now support choice from `Hayao, Paprika, Shinkai, SummerWar`. If you want to use your own dataset, you can modify it to be your own in the configuration file. | ||
|
||
**NOTE :** After modifying the target style, calculate the mean value of the target style dataset at first, and the `transform_anime->Add->value` parameter in `configs/animeganv2.yaml` must be modified. | ||
|
||
The following example shows how to obtain the mean value of the `Hayao` style: | ||
```sh | ||
python tools/animegan_picmean.py --dataset data/animedataset/Hayao/style | ||
image_num: 1792 | ||
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1792/1792 [00:04<00:00, 444.95it/s] | ||
RGB mean diff | ||
[-4.4346957 -8.665916 13.100612 ] | ||
``` | ||
|
||
|
||
### 1.2.3 Test | ||
|
||
test model on `data/animedataset/test/HR_photo` | ||
```sh | ||
python tools/main.py --config-file configs/animeganv2.yaml --evaluate-only --load ${PATH_OF_WEIGHT} | ||
``` | ||
|
||
## 1.3 Results | ||
| original image | style image | | ||
| ----------------------------------- | ---------------------------------- | | ||
|  |  | |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.