StyleGAN training #1446

Merged
merged 45 commits into from
Sep 17, 2020
e12934e
scripts/segmentation/train.py
xdeng7 Apr 28, 2020
c233e03
Merge remote-tracking branch 'upstream/master'
xdeng7 Apr 28, 2020
04a8625
gluoncv/loss.py
xdeng7 Apr 28, 2020
06b19e8
attention.py
xdeng7 Apr 29, 2020
4288833
danet.py
xdeng7 Apr 29, 2020
f730cec
Merge remote-tracking branch 'upstream/master'
xdeng7 May 1, 2020
b5c7dc0
update gluon-cv
xdeng7 Jul 8, 2020
8ed3918
fix bug
xdeng7 Jul 8, 2020
ee82f33
fix bug
xdeng7 Jul 8, 2020
09267fb
add stylegan
xdeng7 Aug 2, 2020
7a71d1d
fix demo name
xdeng7 Aug 2, 2020
a76988a
Delete demo.py
xdeng7 Aug 2, 2020
c40d798
Merge remote-tracking branch 'upstream/master'
xdeng7 Aug 2, 2020
3ea3e7f
Merge branch 'master' into debug
xdeng7 Aug 2, 2020
e505539
Merge branch 'debug' of https://github.com/xdeng7/gluon-cv into debug
xdeng7 Aug 2, 2020
0d16840
fix debug
xdeng7 Aug 2, 2020
ebb3a41
fix bug
xdeng7 Aug 2, 2020
972a87e
fix debug
xdeng7 Aug 2, 2020
625b947
add README
xdeng7 Aug 2, 2020
229d518
stylegan: README
xdeng7 Aug 2, 2020
01cc1df
Add sample images
xdeng7 Aug 2, 2020
58e0f79
fix README
xdeng7 Aug 2, 2020
626d562
Merge branch 'debug' of https://github.com/xdeng7/gluon-cv into debug
xdeng7 Aug 2, 2020
bb88351
fix README
xdeng7 Aug 2, 2020
974aac5
add description
xdeng7 Aug 2, 2020
fddd6f2
fix README
xdeng7 Aug 2, 2020
a1b5715
Merge remote-tracking branch 'upstream/master'
xdeng7 Aug 4, 2020
dcc7a61
Merge branch 'master' into debug
xdeng7 Aug 4, 2020
e16b572
fix final
xdeng7 Aug 4, 2020
17949e0
fix README
xdeng7 Aug 4, 2020
dd4628c
add training
xdeng7 Sep 15, 2020
2e21e42
add training
xdeng7 Sep 16, 2020
61fa6a9
stylegan: add training
xdeng7 Sep 16, 2020
cb2f582
stylegan: add prepare data
xdeng7 Sep 16, 2020
ae9e102
stylegan: modity README
xdeng7 Sep 16, 2020
58e3f78
stylegan: add training sample image
xdeng7 Sep 16, 2020
3d65a06
stylegan: fix README
xdeng7 Sep 16, 2020
8458d6d
delete image
xdeng7 Sep 16, 2020
f00e31e
Merge branch 'debug' of https://github.com/xdeng7/gluon-cv into debug
xdeng7 Sep 16, 2020
4ba61fd
fix README
xdeng7 Sep 16, 2020
a2bf4d7
add image
xdeng7 Sep 16, 2020
4ae90c0
Merge branch 'master' into debug
bryanyzhu Sep 16, 2020
c7aff70
fix README
xdeng7 Sep 17, 2020
87407ec
stylegan: fix minor changes
xdeng7 Sep 17, 2020
79ec92b
stylegan:fix minor changes
xdeng7 Sep 17, 2020
19 changes: 18 additions & 1 deletion scripts/gan/stylegan/README.md
@@ -2,7 +2,24 @@

**Train StyleGAN**

Instructions coming soon.
First, prepare the dataset for training. Download the FFHQ dataset from https://github.com/NVlabs/ffhq-dataset and save it to DATASET_PATH.
In the command below, LMDB_PATH is the directory where the output dataset is saved, N_WORKER is the number of workers, and DATASET_PATH is the path to the folder of downloaded FFHQ images.
```bash
python prepare_data.py --out LMDB_PATH --n_worker N_WORKER DATASET_PATH
```

Second, train StyleGAN on the FFHQ dataset.
```bash
python train.py --path LMDB_PATH --sched
```
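The two steps above can also be driven from a single Python script. The sketch below only assembles the argument lists from the flags shown in this README. The helper names `build_prepare_cmd` and `build_train_cmd`, the default of 8 workers, and the example paths are assumptions made for illustration, not part of the repository.

```python
import sys


def build_prepare_cmd(lmdb_path, dataset_path, n_worker=8):
    """Argument list for the data-preparation step (flags as shown in the README)."""
    return [sys.executable, "prepare_data.py",
            "--out", lmdb_path,
            "--n_worker", str(n_worker),  # default of 8 is an arbitrary assumption
            dataset_path]


def build_train_cmd(lmdb_path):
    """Argument list for the training step (flags as shown in the README)."""
    return [sys.executable, "train.py", "--path", lmdb_path, "--sched"]


if __name__ == "__main__":
    # Hypothetical paths; replace them with your own locations.
    for cmd in (build_prepare_cmd("ffhq_lmdb", "ffhq_images"),
                build_train_cmd("ffhq_lmdb")):
        print(" ".join(cmd))
        # To actually run a step: import subprocess; subprocess.run(cmd, check=True)
```

Building the commands as lists rather than as one shell string avoids quoting problems when a path contains spaces.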

*Notes for training*
1) The original TensorFlow implementation cannot be fully ported to MXNet. Two functions are missing: [gradient penalty](https://github.com/NVlabs/stylegan/blob/66813a32aac5045fcde72751522a0c0ba963f6f2/training/loss.py#L50) and [blur](https://github.com/NVlabs/stylegan/blob/66813a32aac5045fcde72751522a0c0ba963f6f2/training/networks_stylegan.py#L96). Without the gradient penalty, training can suffer mode collapse, so it is necessary to tune the learning rate to the number of GPUs and to apply early stopping. Without the blur function, image quality is lower; this is one of the main reasons our implementation cannot generate high-resolution images.
2) For the reasons above, StyleGAN training is not yet stable. We have tested training on 8 K80 GPUs and on a single GPU; single-GPU training can be problematic. The following images were generated by a model trained on 8 K80 GPUs.
3) Training a StyleGAN that generates 128x128 images takes around 4 days on 8 K80 GPUs.

![images](sample_train.png "Generated 128x128 FFHQ images from the trained StyleGAN")
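To give a sense of what the missing blur step does, here is a minimal pure-Python sketch of a low-pass filter built from the `[1, 2, 1]` taps that the linked reference implementation uses by default. The function names, the single-channel 2D-list input, and the zero padding at the borders are assumptions for illustration; this is not the MXNet code from this PR.

```python
def blur_kernel(taps=(1, 2, 1)):
    """Normalized 2D kernel: outer product of the 1D taps, scaled to sum to 1."""
    total = float(sum(taps)) ** 2
    return [[a * b / total for b in taps] for a in taps]


def blur2d(image, kernel):
    """Filter a 2D list of floats; zero padding keeps the output the same size."""
    h, w = len(image), len(image[0])
    k = len(kernel)
    r = k // 2
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            acc = 0.0
            for dy in range(k):
                for dx in range(k):
                    yy, xx = y + dy - r, x + dx - r
                    # Pixels outside the image contribute zero.
                    if 0 <= yy < h and 0 <= xx < w:
                        acc += image[yy][xx] * kernel[dy][dx]
            out[y][x] = acc
    return out


if __name__ == "__main__":
    impulse = [[0.0] * 5 for _ in range(5)]
    impulse[2][2] = 1.0
    for row in blur2d(impulse, blur_kernel()):
        print(row)
```

Filtering a single bright pixel spreads it over its 3x3 neighbourhood, which is the smoothing effect the note describes as missing.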

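For context on the missing gradient penalty: the linked loss file defines more than one penalty variant, and the note does not say which one is meant. Assuming the R1 form, which penalizes the discriminator's gradient on real samples only, the term added to the discriminator loss would be:

```latex
R_1 = \frac{\gamma}{2}\,
      \mathbb{E}_{x \sim p_{\text{data}}}
      \left[ \lVert \nabla_x D(x) \rVert_2^2 \right]
```

Here $D$ is the discriminator, $x$ is a real image, and $\gamma$ is the penalty weight. Because the term already contains a gradient of $D$, minimizing it with respect to the discriminator's parameters requires second-order derivatives.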

**Test StyleGAN**

15 changes: 10 additions & 5 deletions scripts/gan/stylegan/demo_stylegan.py
@@ -50,7 +50,8 @@ def save_image(data, file, normalize=True, img_range=None):
img_range = [min(data), max(data)]

norm_img = normalize_image(data, img_range[0], img_range[1])
img = nd.clip(norm_img * 255 + 0.5, 0, 255).asnumpy().astype(np.uint8)
img = nd.clip(norm_img * 255 + 0.5, 0, 255).asnumpy().astype(np.uint8)

img = Image.fromarray(np.transpose(img, (1, 2, 0)))
img.save(file)
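The scaling in this hunk can be checked on single values. The sketch below assumes that `normalize_image`, whose body is not shown in this diff, clamps to the given range and rescales linearly to [0, 1]; `normalize_value` and `to_uint8` are hypothetical scalar stand-ins, not functions from the script.

```python
def normalize_value(x, low, high):
    """Clamp x to [low, high], then rescale linearly to [0, 1] (assumed behaviour)."""
    x = min(max(x, low), high)
    return (x - low) / (high - low)


def to_uint8(x, low=-1.0, high=1.0):
    """Mirror `norm * 255 + 0.5`, clip to [0, 255], then truncate like astype(uint8)."""
    scaled = normalize_value(x, low, high) * 255 + 0.5
    return int(min(max(scaled, 0.0), 255.0))


if __name__ == "__main__":
    for value in (-1.0, 0.0, 1.0):
        print(value, to_uint8(value))
```

Adding 0.5 before truncation rounds to the nearest integer instead of always rounding down.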

@@ -61,16 +62,18 @@ def save_image(data, file, normalize=True, img_range=None):
parser.add_argument('--n_sample', type=int, default=10, help='number of rows of sample matrix')
parser.add_argument('--gpu_id', type=str, default='0', help='gpu id: e.g. 0. use -1 for CPU')
parser.add_argument('--out_dir', type=str, default='samples/', help='output directory for samples')
parser.add_argument('--path', type=str, default='./stylegan-ffhq-1024px-new.params',
parser.add_argument('--path', type=str, default='./stylegan-ffhq-1024px-new.params',
help='path to checkpoint file')

args = parser.parse_args()

args = parser.parse_args()
if args.gpu_id == '-1':
    device = mx.cpu()
else:
    device = mx.gpu(int(args.gpu_id.strip()))

generator = StyledGenerator(code_dim=512)
generator = StyledGenerator(512, blur=True)

generator.initialize()
generator.collect_params().reset_ctx(device)
generator.load_parameters(args.path, ctx=device)
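The argument handling in this hunk can be tried without MXNet installed. The sketch below rebuilds the four options shown in the diff and replaces the `mx.cpu()` / `mx.gpu(i)` calls with a plain tuple. `build_parser` and `pick_device` are hypothetical helper names, and stripping whitespace before the `'-1'` comparison is a small deviation from the script, which compares first and strips afterwards.

```python
import argparse


def build_parser():
    """Recreate the options visible in the diff; other options are omitted."""
    parser = argparse.ArgumentParser(description='StyleGAN demo (sketch)')
    parser.add_argument('--n_sample', type=int, default=10,
                        help='number of rows of sample matrix')
    parser.add_argument('--gpu_id', type=str, default='0',
                        help='gpu id: e.g. 0. use -1 for CPU')
    parser.add_argument('--out_dir', type=str, default='samples/',
                        help='output directory for samples')
    parser.add_argument('--path', type=str,
                        default='./stylegan-ffhq-1024px-new.params',
                        help='path to checkpoint file')
    return parser


def pick_device(gpu_id):
    """Return ('cpu', 0) for '-1', otherwise ('gpu', index)."""
    gpu_id = gpu_id.strip()  # strip first so ' -1' also selects the CPU
    if gpu_id == '-1':
        return ('cpu', 0)
    return ('gpu', int(gpu_id))


if __name__ == "__main__":
    args = build_parser().parse_args([])
    print(pick_device(args.gpu_id), args.out_dir, args.path)
```

In the real script the tuple would be mapped back to `mx.cpu()` or `mx.gpu(index)`.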
@@ -82,7 +85,9 @@ def save_image(data, file, normalize=True, img_range=None):
imgs = sample(generator, step, mean_style, args.n_sample, device)

if not os.path.isdir(args.out_dir):
    os.makedirs(args.out_dir)
    os.makedirs(args.out_dir)

for i in range(args.n_sample):
    save_image(imgs[i], os.path.join(args.out_dir, 'sample_{}.png'.format(i)), normalize=True, img_range=(-1, 1))
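The directory check and file naming in this hunk could be written more compactly with `os.makedirs(..., exist_ok=True)`, which makes the separate `os.path.isdir` test unnecessary. The sketch below is one possible variant, not the code from the PR; `sample_paths` is a hypothetical helper and no images are written.

```python
import os
import tempfile


def sample_paths(out_dir, n_sample):
    """Create out_dir if needed and return one PNG path per sample."""
    os.makedirs(out_dir, exist_ok=True)  # no error if the directory already exists
    return [os.path.join(out_dir, 'sample_{}.png'.format(i))
            for i in range(n_sample)]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for path in sample_paths(os.path.join(tmp, 'samples'), 3):
            print(os.path.basename(path))
```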

