diff --git a/kraken/lib/train.py b/kraken/lib/train.py
index 92081bec6..774607645 100644
--- a/kraken/lib/train.py
+++ b/kraken/lib/train.py
@@ -901,38 +901,41 @@ def validation_step(self, batch, batch_idx):
         x, y = batch['image'], batch['target']
         pred, _ = self.nn.nn(x)
         # scale target to output size
-        y = F.interpolate(y, size=(pred.size(2), pred.size(3))).int()
-
-        self.val_px_accuracy.update(pred, y)
-        self.val_mean_accuracy.update(pred, y)
-        self.val_mean_iu.update(pred, y)
-        self.val_freq_iu.update(pred, y)
+        y = F.interpolate(y, size=(pred.size(2), pred.size(3)), mode='nearest').int()
+        # Get regions for IoU metrics
+        reg_idxs = sorted(self.nn.user_metadata['class_mapping']['regions'].values())
+        pred_reg = pred[:, reg_idxs, ...]
+        y_reg = y[:, reg_idxs, ...]
+        self.val_region_px_accuracy.update(pred_reg, y_reg)
+        self.val_region_mean_accuracy.update(pred_reg, y_reg)
+        self.val_region_mean_iu.update(pred_reg, y_reg)
+        self.val_region_freq_iu.update(pred_reg, y_reg)
 
     def on_validation_epoch_end(self):
         if not self.trainer.sanity_checking:
-            pixel_accuracy = self.val_px_accuracy.compute()
-            mean_accuracy = self.val_mean_accuracy.compute()
-            mean_iu = self.val_mean_iu.compute()
-            freq_iu = self.val_freq_iu.compute()
+            pixel_accuracy = self.val_region_px_accuracy.compute()
+            mean_accuracy = self.val_region_mean_accuracy.compute()
+            mean_iu = self.val_region_mean_iu.compute()
+            freq_iu = self.val_region_freq_iu.compute()
 
             if mean_iu > self.best_metric:
-                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
+                logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
                 self.best_epoch = self.current_epoch
                 self.best_metric = mean_iu
 
             logger.info(f'validation run: accuracy {pixel_accuracy} mean_acc {mean_accuracy} mean_iu {mean_iu} freq_iu {freq_iu}')
 
-            self.log('val_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_region_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
            self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)
 
         # reset metrics even if sanity checking
-        self.val_px_accuracy.reset()
-        self.val_mean_accuracy.reset()
-        self.val_mean_iu.reset()
-        self.val_freq_iu.reset()
+        self.val_region_px_accuracy.reset()
+        self.val_region_mean_accuracy.reset()
+        self.val_region_mean_iu.reset()
+        self.val_region_freq_iu.reset()
 
     def setup(self, stage: Optional[str] = None):
         # finalize models in case of appending/loading
@@ -1055,10 +1058,14 @@ def setup(self, stage: Optional[str] = None):
         torch.set_num_threads(max(self.num_workers, 1))
 
         # set up validation metrics after output classes have been determined
-        self.val_px_accuracy = MultilabelAccuracy(average='micro', num_labels=self.train_set.dataset.num_classes)
-        self.val_mean_accuracy = MultilabelAccuracy(average='macro', num_labels=self.train_set.dataset.num_classes)
-        self.val_mean_iu = MultilabelJaccardIndex(average='macro', num_labels=self.train_set.dataset.num_classes)
-        self.val_freq_iu = MultilabelJaccardIndex(average='weighted', num_labels=self.train_set.dataset.num_classes)
+        # baseline metrics
+        # region metrics
+        num_regions = len(self.val_set.dataset.class_mapping['regions'])
+        self.val_region_px_accuracy = MultilabelAccuracy(average='micro', num_labels=num_regions)
+        self.val_region_mean_accuracy = MultilabelAccuracy(average='macro', num_labels=num_regions)
+        self.val_region_mean_iu = MultilabelJaccardIndex(average='macro', num_labels=num_regions)
+        self.val_region_freq_iu = MultilabelJaccardIndex(average='weighted', num_labels=num_regions)
+
     def train_dataloader(self):
         return DataLoader(self.train_set,
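
For reference, a minimal standalone sketch of the region-metric path introduced above. The class_mapping dict, tensor shapes, and random values below are made-up stand-ins for self.nn.user_metadata['class_mapping'] and the segmentation model's output; they are illustrative only and not part of the patch.

    import torch
    from torchmetrics.classification import MultilabelAccuracy, MultilabelJaccardIndex

    # hypothetical mapping: channels 0-1 are baseline classes, 2-3 are region classes
    class_mapping = {'baselines': {'default': 0, 'header': 1},
                     'regions': {'text': 2, 'image': 3}}

    num_classes, h, w = 4, 32, 32
    pred = torch.rand(1, num_classes, h, w)                 # per-channel probabilities
    target = torch.randint(0, 2, (1, num_classes, h, w))    # binary ground-truth masks

    # slice out only the region channels, mirroring validation_step
    reg_idxs = sorted(class_mapping['regions'].values())
    pred_reg = pred[:, reg_idxs, ...]
    y_reg = target[:, reg_idxs, ...]

    # metrics sized to the number of region classes, mirroring setup()
    num_regions = len(class_mapping['regions'])
    region_px_accuracy = MultilabelAccuracy(average='micro', num_labels=num_regions)
    region_mean_iu = MultilabelJaccardIndex(average='macro', num_labels=num_regions)

    # float predictions are thresholded at 0.5 by torchmetrics' multilabel metrics
    region_px_accuracy.update(pred_reg, y_reg)
    region_mean_iu.update(pred_reg, y_reg)
    print(region_px_accuracy.compute(), region_mean_iu.compute())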