Skip to content

Commit

Permalink
fix fid (PaddlePaddle#336)
Browse files Browse the repository at this point in the history
* fix fid

* fix fid

* add pixel2pixel facades model
  • Loading branch information
lzzyzlbb authored Jun 4, 2021
1 parent bab376f commit 1f335bb
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 10 deletions.
10 changes: 10 additions & 0 deletions configs/pix2pix_cityscapes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,13 @@ log_config:

snapshot_config:
interval: 5

validate:
interval: 500
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8


10 changes: 10 additions & 0 deletions configs/pix2pix_cityscapes_2gpus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,13 @@ log_config:

snapshot_config:
interval: 5

validate:
interval: 500
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8


10 changes: 10 additions & 0 deletions configs/pix2pix_facades.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,13 @@ log_config:

snapshot_config:
interval: 5

validate:
interval: 500
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8


1 change: 1 addition & 0 deletions docs/en_US/tutorials/pix2pix_cyclegan.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
| 模型 | 数据集 | 下载地址 |
|---|---|---|
| Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams)
| Pix2Pix_facedes | facades | [Pix2Pix_facades](https://paddlegan.bj.bcebos.com/models/Pixel2Pixel_facades.pdparams)



Expand Down
1 change: 1 addition & 0 deletions docs/zh_CN/tutorials/pix2pix_cyclegan.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
| 模型 | 数据集 | 下载地址 |
|---|---|---|
| Pix2Pix_cityscapes | cityscapes | [Pix2Pix_cityscapes](https://paddlegan.bj.bcebos.com/models/Pix2Pix_cityscapes.pdparams)
| Pix2Pix_facedes | facades | [Pix2Pix_facades](https://paddlegan.bj.bcebos.com/models/Pixel2Pixel_facades.pdparams)


# 2 CycleGAN
Expand Down
26 changes: 17 additions & 9 deletions ppgan/metrics/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,30 @@ def __init__(self, batch_size=1, use_GPU=True, dims = 2048, premodel_path=None,
premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)
self.model = model
param_dict = paddle.load(premodel_path)
model.load_dict(param_dict)
model.eval()
self.model.load_dict(param_dict)
self.model.eval()
self.reset()

def reset(self):
self.preds = []
self.gts = []
self.results = []

def update(self, preds, gts):
value = calculate_fid_given_img(preds, gts, self.batch_size, self.model, self.use_GPU, self.dims)
self.results.append(value)

if len(preds.shape) >=4:
self.preds.append(preds)
self.gts.append(gts)
else:
for i in range(preds.shape[0]):
self.preds.append(preds[i,:,:,:,:])
self.gts.append(gts[i,:,:,:,:])

def accumulate(self):
if len(self.results) <= 0:
return 0.
return np.mean(self.results)
self.preds = paddle.concat(self.preds, axis=0)
self.gts = paddle.concat(self.gts, axis=0)
value = calculate_fid_given_img(self.preds, self.gts, self.batch_size, self.model, self.use_GPU, self.dims)
self.reset()
return value

def name(self):
return 'FID'
Expand Down Expand Up @@ -123,7 +132,6 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
images = img[start:end]
if images.shape[1] != 3:
images = images.transpose((0, 3, 1, 2))
images /= 255

images = paddle.to_tensor(images)
pred = model(images)[0][0]
Expand Down
7 changes: 6 additions & 1 deletion ppgan/models/pix2pix_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,10 @@ def train_iter(self, optimizers=None):
optimizers['optimG'].step()

def test_iter(self, metrics=None):
self.nets['netG'].eval()
self.forward()
with paddle.no_grad():
self.forward()
if metrics is not None:
for metric in metrics.values():
metric.update(self.fake_B, self.real_B)
self.nets['netG'].train()

0 comments on commit 1f335bb

Please sign in to comment.