Skip to content

Commit

Permalink
Compute pixel metrics only on regions
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Sep 27, 2024
1 parent b8c4c2b commit 5b9e2a0
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions kraken/lib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,38 +901,41 @@ def validation_step(self, batch, batch_idx):
x, y = batch['image'], batch['target']
pred, _ = self.nn.nn(x)
# scale target to output size
y = F.interpolate(y, size=(pred.size(2), pred.size(3))).int()

self.val_px_accuracy.update(pred, y)
self.val_mean_accuracy.update(pred, y)
self.val_mean_iu.update(pred, y)
self.val_freq_iu.update(pred, y)
y = F.interpolate(y, size=(pred.size(2), pred.size(3)), mode='nearest').int()
# Get regions for IoU metrics
reg_idxs = sorted(self.nn.user_metadata['class_mapping']['regions'].values())
pred_reg = [:, reg_idxs, ...]
y_reg = y[:, reg_idxs, ...]
self.val_region_px_accuracy.update(pred_reg, y_reg)
self.val_region_mean_accuracy.update(pred_reg, y_reg)
self.val_region_mean_iu.update(pred_reg, y_reg)
self.val_region_freq_iu.update(pred_reg, y_reg)

def on_validation_epoch_end(self):
if not self.trainer.sanity_checking:
pixel_accuracy = self.val_px_accuracy.compute()
mean_accuracy = self.val_mean_accuracy.compute()
mean_iu = self.val_mean_iu.compute()
freq_iu = self.val_freq_iu.compute()
pixel_accuracy = self.val_region_px_accuracy.compute()
mean_accuracy = self.val_region_mean_accuracy.compute()
mean_iu = self.val_region_mean_iu.compute()
freq_iu = self.val_region_freq_iu.compute()

if mean_iu > self.best_metric:
logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
self.best_epoch = self.current_epoch
self.best_metric = mean_iu

logger.info(f'validation run: accuracy {pixel_accuracy} mean_acc {mean_accuracy} mean_iu {mean_iu} freq_iu {freq_iu}')

self.log('val_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_region_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)

# reset metrics even if sanity checking
self.val_px_accuracy.reset()
self.val_mean_accuracy.reset()
self.val_mean_iu.reset()
self.val_freq_iu.reset()
self.val_region_px_accuracy.reset()
self.val_region_mean_accuracy.reset()
self.val_region_mean_iu.reset()
self.val_region_freq_iu.reset()

def setup(self, stage: Optional[str] = None):
# finalize models in case of appending/loading
Expand Down Expand Up @@ -1055,10 +1058,14 @@ def setup(self, stage: Optional[str] = None):
torch.set_num_threads(max(self.num_workers, 1))

# set up validation metrics after output classes have been determined
self.val_px_accuracy = MultilabelAccuracy(average='micro', num_labels=self.train_set.dataset.num_classes)
self.val_mean_accuracy = MultilabelAccuracy(average='macro', num_labels=self.train_set.dataset.num_classes)
self.val_mean_iu = MultilabelJaccardIndex(average='macro', num_labels=self.train_set.dataset.num_classes)
self.val_freq_iu = MultilabelJaccardIndex(average='weighted', num_labels=self.train_set.dataset.num_classes)
# baseline metrics
# region metrics
num_regions = len(self.val_set.dataset.class_mapping['regions'])
self.val_region_px_accuracy = MultilabelAccuracy(average='micro', num_labels=num_regions)
self.val_region_mean_accuracy = MultilabelAccuracy(average='macro', num_labels=num_regions)
self.val_region_mean_iu = MultilabelJaccardIndex(average='macro', num_labels=num_regions)
self.val_region_freq_iu = MultilabelJaccardIndex(average='weighted', num_labels=num_regions)


def train_dataloader(self):
return DataLoader(self.train_set,
Expand Down

0 comments on commit 5b9e2a0

Please sign in to comment.