Skip to content

Commit

Permalink
fix bug in saving scores
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottwu committed Apr 17, 2020
1 parent 2865226 commit b756264
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions unsup3d/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def save_results(self, save_dir):
utils.save_videos(save_dir, canon_im_rotate_grid, suffix='image_video', sep_folder=sep_folder, cycle=True)
utils.save_videos(save_dir, canon_normal_rotate_grid, suffix='normal_video', sep_folder=sep_folder, cycle=True)

# save gt and scores if gt is loaded
# save scores if gt is loaded
if self.load_gt_depth:
depth_gt = ((self.depth_gt[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
normal_gt = self.normal_gt[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
Expand All @@ -386,12 +386,14 @@ def save_results(self, save_dir):
self.all_scores = torch.cat([self.all_scores, all_scores], 0)

def save_scores(self, path):
header = 'MAE_masked, \
MSE_masked, \
SIE_masked, \
NorErr_masked'
mean = self.all_scores.mean(0)
std = self.all_scores.std(0)
header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean])
header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std])
utils.save_scores(path, self.all_scores, header=header)
# save scores if gt is loaded
if self.load_gt_depth:
header = 'MAE_masked, \
MSE_masked, \
SIE_masked, \
NorErr_masked'
mean = self.all_scores.mean(0)
std = self.all_scores.std(0)
header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean])
header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std])
utils.save_scores(path, self.all_scores, header=header)

0 comments on commit b756264

Please sign in to comment.