-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain.py
27 lines (24 loc) · 872 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from models.GAN import GeneativeModel, Discriminator, gan_train
from utils.font_test import common_han
import numpy as np
import torch
# Training entry script: load the source/target font arrays and the
# category embedding layers from .npz archives, build the generator and
# discriminator, and hand everything to gan_train.

# Load source fonts & target fonts.
# np.load on an .npz archive returns an NpzFile that keeps the underlying
# file open; using it as a context manager closes the handle once the
# arrays have been read out (npz members are loaded fully into memory,
# so the extracted arrays remain valid after the archive is closed).
with np.load("./fonts/font.npz") as datasets:
    source_fonts = datasets['source_fonts']
    target_fonts = datasets['target_fonts']

# Load category embedded layers stored under keys 'cl1' .. 'cl6' and
# convert each one to a torch tensor. torch.Tensor(...) is kept (rather
# than torch.from_numpy) so the resulting dtype matches the original script.
with np.load("./category_emb.npz") as category_:
    category_emb = {
        f'cl{i}': torch.Tensor(category_[f'cl{i}']) for i in range(1, 7)
    }

# NOTE: "GeneativeModel" is the name exported by models.GAN (sic).
generator = GeneativeModel()
discriminator = Discriminator()
gan_train(generator,
          discriminator,
          source_fonts,
          target_fonts,
          category_emb)