From e1f213cb0bb49148b517813690400ea3cc2285c3 Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Sat, 7 Oct 2023 17:12:38 +0200 Subject: [PATCH] fix: end of training metrics computation --- train.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index 3c736613d..32ce52c10 100644 --- a/train.py +++ b/train.py @@ -358,18 +358,27 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal): ###Let's compute final FID if rank_0 and opt.train_compute_metrics_test: - cur_fid = model.compute_metrics_test() + with torch.no_grad(): + if use_temporal: + dataloaders_test = zip(dataloader_test, dataloader_test_temporal) + else: + dataloaders_test = zip(dataloader_test) + model.compute_metrics_test( + dataloaders_test, opt.train_epoch_count - 1, total_iters + ) + cur_metrics = model.get_current_metrics() path_json = os.path.join(opt.checkpoints_dir, opt.name, "eval_results.json") - if os.path.exists(path_json): with open(path_json, "r") as loadfile: data = json.load(loadfile) with open(path_json, "w+") as outfile: data = {} - data["fid_%s_img_%s_epochs" % (opt.data_max_dataset_size, epoch)] = float( - cur_fid.item() - ) + for key, value in cur_metrics.items(): + data[ + "%s_%s_img_%s" + % (key, opt.data_max_dataset_size, opt.train_epoch_count) + ] = float(value) json.dump(data, outfile) if rank_0: