Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

segmentation and depth ds #8

Merged
merged 14 commits into from
Aug 25, 2020
7 changes: 7 additions & 0 deletions configs/meta_arch/depth-segmentator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# @package _group_

# Lightning module class for this meta-architecture
# (presumably instantiated by the training entry point — confirm).
pl_class:
  _target_: sharpf.modeling.DepthSegmentator
# Metric name to monitor; DepthSegmentator reports `val_balanced_accuracy`
# at the end of each validation epoch.
monitor: val_balanced_accuracy
# Loss, called as loss(preds, target) in DepthSegmentator.training_step.
# NOTE(review): this loss applies the sigmoid itself, so the model is expected
# to emit raw logits — confirm that the model head does not end in a Sigmoid.
loss:
  _target_: torch.nn.functional.binary_cross_entropy_with_logits
19 changes: 19 additions & 0 deletions configs/model/unet-segmentator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# @package _group_
model_name: PixelSegmentator
params:
  feature_extractor:
    _target_: sharpf.modeling.Unet
    params:
      encoder_name: resnet50
      decoder_use_batchnorm: true
      decoder_channels: [256, 128, 64, 32, 16]
      decoder_attention_type: null
      in_channels: 1
  # The head must emit raw logits: the segmentation loss is
  # binary_cross_entropy_with_logits and DepthSegmentator.forward applies
  # `.sigmoid()` itself before thresholding at 0.5. A trailing torch.nn.Sigmoid
  # here would squash the output twice, so the thresholded mask would always be 1.
  segmentation_head:
    - _target_: torch.nn.Conv2d
      params:
        in_channels: 16
        out_channels: 1
        kernel_size: 3
        padding: 1
15 changes: 15 additions & 0 deletions configs/transforms/depth-norm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# @package _group_

# Per-partition transform pipelines; every entry is built through its `_target_`.
train:
  - _target_: sharpf.utils.abc_utils.torch.TypeCast

val:
  - _target_: sharpf.utils.abc_utils.torch.TypeCast

test:
  - _target_: sharpf.utils.abc_utils.torch.TypeCast

# Flat list of normalisation step names. DepthDataset checks
# `'quantile' in normalisation` and `'standartize' in normalisation`, so the
# names must be direct list items; nesting them under a `data:` mapping makes
# both checks fail and silently disables normalisation.
normalisation: ['standartize', 'quantile']


3 changes: 2 additions & 1 deletion sharpf/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .meta_arch import PointSharpnessRegressor, DepthRegressor
from .meta_arch import PointSharpnessRegressor, DepthRegressor, DepthSegmentator
from .metrics import balanced_accuracy
from .model import MODEL_REGISTRY, build_model, DGCNN, Unet, PixelRegressor
from .modules import (
AggregationMax,
Expand Down
1 change: 1 addition & 0 deletions sharpf/modeling/meta_arch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .point_sharpness_regressor import PointSharpnessRegressor
from .depth_regressor import DepthRegressor
from .depth_segmentator import DepthSegmentator
19 changes: 11 additions & 8 deletions sharpf/modeling/meta_arch/depth_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sharpf.utils.comm import get_batch_size
from ..model.build import build_model
from ...data import DepthMapIO
from ...utils.abc_utils import LotsOfHdf5Files
from ...utils.abc_utils.hdf5.dataset import LotsOfHdf5Files, DepthDataset
from ...utils.abc_utils.torch import CompositeTransform

log = logging.getLogger(__name__)
Expand All @@ -27,7 +27,8 @@ class DepthRegressor(LightningModule):

def __init__(self, cfg):
super().__init__()
self.hparams = cfg
self.hparams = cfg # there should be better official way later
self.task = 'regression'
self.model = build_model(self.hparams.model)
self.example_input_array = torch.rand(1, 1, 64, 64)
self.data_dir = hydra.utils.to_absolute_path(self.hparams.data.data_dir)
Expand All @@ -43,7 +44,7 @@ def forward(self, x):
return self.model(x)

def training_step(self, batch, batch_idx):
points, distances = batch['image'], batch['distances']
points, distances = batch['image'], batch['distance_to_sharp']
points = points.unsqueeze(1) if points.dim() == 3 else points
preds = self.forward(points)
loss = hydra.utils.instantiate(self.hparams.meta_arch.loss, preds, distances)
Expand All @@ -52,7 +53,7 @@ def training_step(self, batch, batch_idx):
return result

def _shared_eval_step(self, batch, batch_idx, prefix):
points, distances = batch['image'], batch['distances']
points, distances = batch['image'], batch['distance_to_sharp']
points = points.unsqueeze(1) if points.dim() == 3 else points
preds = self.forward(points)

Expand Down Expand Up @@ -94,12 +95,14 @@ def configure_optimizers(self):
def _get_dataset(self, partition):
if hasattr(self, f'{partition}_set') and getattr(self, f'{partition}_set') is not None:
return getattr(self, f'{partition}_set')
transform = CompositeTransform([hydra.utils.instantiate(tf) for tf in self.hparams.transforms[partition]])
return LotsOfHdf5Files(

transform = CompositeTransform([hydra.utils.instantiate(tf) for tf in self.cfg.transforms[partition]])
return DepthDataset(
data_dir=self.data_dir,
io=DepthMapIO,
data_label=self.hparams.data.data_label,
target_label=self.hparams.data.target_label,
data_label=self.cfg.data.data_label,
target_label=self.cfg.data.target_label,
task=self.task,
partition=partition,
transform=transform,
max_loaded_files=self.hparams.data.max_loaded_files
Expand Down
137 changes: 137 additions & 0 deletions sharpf/modeling/meta_arch/depth_segmentator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import logging

import hydra
import torch
import torch.nn as nn
from pytorch_lightning import TrainResult
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.metrics.functional import stat_scores
from torch.utils.data import DataLoader

from sharpf.utils.comm import get_batch_size, all_gather, synchronize
from ..metrics import balanced_accuracy
from ..model.build import build_model
from ...data import DepthMapIO
from ...utils.abc_utils.hdf5 import DepthDataset
from ...utils.abc_utils.torch import CompositeTransform

log = logging.getLogger(__name__)


class DepthSegmentator(LightningModule):
    """Lightning module for binary segmentation of depth images.

    The wrapped model is built from ``cfg.model`` and trained against the
    ``close_to_sharp_mask`` target of every batch. Validation and test report
    balanced accuracy computed from per-sample tp/fp/tn/fn statistics.
    """

    def __init__(self, cfg):
        """Build the model and, for multi-device runs, switch to SyncBatchNorm.

        Args:
            cfg: configuration object. This class reads its ``model``, ``data``,
                ``trainer``, ``meta_arch``, ``opt``, ``scheduler``, ``transforms``
                and ``data_loader`` nodes.
        """
        super().__init__()
        self.hparams = cfg
        self.task = 'segmentation'
        self.model = build_model(self.hparams.model)
        self.example_input_array = torch.rand(1, 1, 64, 64)
        self.data_dir = hydra.utils.to_absolute_path(self.hparams.data.data_dir)

        # SyncBatchNorm is needed whenever the batch is split across processes:
        # either a ddp backend is requested explicitly, or no backend is given
        # but more than one GPU / node is configured.
        dist_backend = self.hparams.trainer.distributed_backend
        if (dist_backend is not None and 'ddp' in dist_backend) or (
                dist_backend is None and self.hparams.trainer.gpus is not None and (
                self.hparams.trainer.gpus > 1 or self.hparams.trainer.num_nodes > 1)):
            log.info('Converting BatchNorm to SyncBatchNorm. Do not forget other batch-dimension dependent operations.')
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)

    def forward(self, x, as_mask=True):
        """Run the model once and return either a hard mask or the raw output.

        Args:
            x: input batch passed straight to the model.
            as_mask: when True, return ``(sigmoid(out) > 0.5)`` as a long tensor;
                when False, return the model output unchanged.

        Returns:
            The binary mask or the raw model output, depending on ``as_mask``.
        """
        out = self.model(x)
        if as_mask:
            # NOTE(review): thresholding sigmoid(out) assumes the model emits raw logits.
            return (out.sigmoid() > 0.5).long()
        # Reuse the output computed above; calling the model a second time would
        # double the cost and update BatchNorm running statistics twice per step.
        return out

    def training_step(self, batch, batch_idx):
        """Compute the configured loss on one batch and log it as ``train_loss``."""
        points, target = batch['image'], batch['close_to_sharp_mask']
        # add a channel dimension when the batch comes as a 3-dim tensor
        points = points.unsqueeze(1) if points.dim() == 3 else points
        preds = self.forward(points, as_mask=False)
        loss = hydra.utils.instantiate(self.hparams.meta_arch.loss, preds, target)
        result = TrainResult(minimize=loss)
        result.log('train_loss', loss, prog_bar=True)
        return result

    def _shared_eval_step(self, batch, batch_idx, prefix):
        """Return per-sample tp/fp/tn/fn/support statistics for class 1."""
        points, target = batch['image'], batch['close_to_sharp_mask']
        points = points.unsqueeze(1) if points.dim() == 3 else points
        preds = self.forward(points, as_mask=True)
        stats = [list(stat_scores(preds[i], target[i], class_index=1)) for i in range(preds.size(0))]
        tp, fp, tn, fn, sup = torch.Tensor(stats).to(preds.device).T.unsqueeze(2)  # each of size (batch, 1)
        return {'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn, 'sup': sup}

    def _shared_eval_epoch_end(self, outputs, prefix):
        """Aggregate step statistics over batches and processes into balanced accuracy."""
        # gather across sub batches
        tp = torch.cat([output['tp'] for output in outputs], dim=0)
        fp = torch.cat([output['fp'] for output in outputs], dim=0)
        tn = torch.cat([output['tn'] for output in outputs], dim=0)
        fn = torch.cat([output['fn'] for output in outputs], dim=0)

        # gather results across gpus
        synchronize()
        tp = torch.cat(all_gather(tp), dim=0)
        fp = torch.cat(all_gather(fp), dim=0)
        tn = torch.cat(all_gather(tn), dim=0)
        fn = torch.cat(all_gather(fn), dim=0)

        # calculate metrics
        ba = balanced_accuracy(tp, fp, tn, fn)

        logs = {f'{prefix}_balanced_accuracy': ba}
        return {f'{prefix}_balanced_accuracy': ba, 'log': logs}

    def validation_step(self, batch, batch_idx):
        """Evaluation step for the validation loop."""
        return self._shared_eval_step(batch, batch_idx, prefix='val')

    def test_step(self, batch, batch_idx):
        """Evaluation step for the test loop."""
        return self._shared_eval_step(batch, batch_idx, prefix='test')

    def validation_epoch_end(self, outputs):
        """Report ``val_balanced_accuracy`` for the finished validation epoch."""
        return self._shared_eval_epoch_end(outputs, prefix='val')

    def test_epoch_end(self, outputs):
        """Report ``test_balanced_accuracy`` for the finished test epoch."""
        return self._shared_eval_epoch_end(outputs, prefix='test')

    def configure_optimizers(self):
        """Instantiate the optimizer and scheduler described by ``cfg.opt`` / ``cfg.scheduler``."""
        optimizer = hydra.utils.instantiate(self.hparams.opt, params=self.parameters())
        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer)
        return [optimizer], [scheduler]

    def _get_dataset(self, partition):
        """Return the cached dataset for ``partition`` or build a new ``DepthDataset``."""
        if hasattr(self, f'{partition}_set') and getattr(self, f'{partition}_set') is not None:
            return getattr(self, f'{partition}_set')
        transform = CompositeTransform([hydra.utils.instantiate(tf) for tf in self.hparams.transforms[partition]])
        # the normalisation entry is optional in the transforms config
        if 'normalisation' in self.hparams.transforms.keys():
            normalisation = self.hparams.transforms['normalisation']
        else:
            normalisation = None

        return DepthDataset(
            data_dir=self.data_dir,
            io=DepthMapIO,
            data_label=self.hparams.data.data_label,
            target_label=self.hparams.data.target_label,
            task=self.task,
            partition=partition,
            transform=transform,
            max_loaded_files=self.hparams.data.max_loaded_files,
            normalisation=normalisation
        )

    def _get_dataloader(self, partition):
        """Wrap the dataset of ``partition`` in a ``DataLoader`` configured from ``cfg.data_loader``."""
        dataset = self._get_dataset(partition)
        num_workers = self.hparams.data_loader[partition].num_workers
        batch_size = get_batch_size(self.hparams.data_loader[partition].total_batch_size)
        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    def setup(self, stage: str):
        """Create the datasets needed for ``stage`` ('fit' or 'test'); the others are set to None."""
        self.train_set = self._get_dataset('train') if stage == 'fit' else None
        self.val_set = self._get_dataset('val') if stage == 'fit' else None
        self.test_set = self._get_dataset('test') if stage == 'test' else None

    def train_dataloader(self):
        """Data loader over the 'train' partition."""
        return self._get_dataloader('train')

    def val_dataloader(self):
        """Data loader over the 'val' partition."""
        return self._get_dataloader('val')

    def test_dataloader(self):
        """Data loader used for testing; currently still reads the 'val' partition."""
        return self._get_dataloader('val')  # FIXME
21 changes: 21 additions & 0 deletions sharpf/modeling/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch


def balanced_accuracy(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
    """
    Compute the balanced accuracy of a single class, averaged over samples.

    When one of the two rates is undefined for a sample (0 / 0, i.e. the
    corresponding class is absent), the other rate is used in its place.

    Args:
        tp (Tensor): of shape (B, 1). True positive values.
        fp (Tensor): of shape (B, 1). False positive values.
        tn (Tensor): of shape (B, 1). True negative values.
        fn (Tensor): of shape (B, 1). False negative values.

    Returns:
        torch.Tensor: balanced accuracy value
    """
    sensitivity = tp / (tp + fn)  # (B, 1)
    specificity = tn / (tn + fp)  # (B, 1)

    # Fill undefined sensitivities first, then use the already filled values
    # to patch undefined specificities.
    sensitivity_undefined = torch.isnan(sensitivity)
    sensitivity = torch.where(sensitivity_undefined, specificity, sensitivity)  # (B, 1)
    specificity_undefined = torch.isnan(specificity)
    specificity = torch.where(specificity_undefined, sensitivity, specificity)  # (B, 1)

    rate_sum = sensitivity + specificity
    return 0.5 * torch.mean(rate_sum)
22 changes: 22 additions & 0 deletions sharpf/modeling/model/pixel_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,25 @@ def from_config(cls, cfg: DictConfig):
"feature_extractor": hydra.utils.instantiate(cfg.feature_extractor),
"regression_head": nn.Sequential(*[hydra.utils.instantiate(node) for node in cfg.regression_head])
}

# still believe that single class or abstract class for pixel-task model would be better
@MODEL_REGISTRY.register()
class PixelSegmentator(nn.Module):
    """Pixel-wise segmentation model: a feature extractor followed by a segmentation head."""

    @configurable
    def __init__(self, feature_extractor, segmentation_head):
        """Store the two sub-modules that make up the model."""
        super().__init__()
        self.feature_extractor = feature_extractor
        self.segmentation_head = segmentation_head

    def initialize(self):
        """Initialize the segmentation head via ``initialize_head``."""
        initialize_head(self.segmentation_head)

    def forward(self, x):
        """Feed ``x`` through the feature extractor, then through the segmentation head."""
        features = self.feature_extractor(x)
        return self.segmentation_head(features)

    @classmethod
    def from_config(cls, cfg: DictConfig):
        """Instantiate the constructor arguments described by ``cfg``."""
        # Build the extractor before the head layers, keeping the original construction order.
        extractor = hydra.utils.instantiate(cfg.feature_extractor)
        head_layers = [hydra.utils.instantiate(node) for node in cfg.segmentation_head]
        return {
            "feature_extractor": extractor,
            "segmentation_head": nn.Sequential(*head_layers),
        }
2 changes: 1 addition & 1 deletion sharpf/utils/abc_utils/hdf5/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .dataset import Hdf5File, LotsOfHdf5Files
from .dataset import Hdf5File, LotsOfHdf5Files, DepthDataset
85 changes: 85 additions & 0 deletions sharpf/utils/abc_utils/hdf5/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

log = logging.getLogger(__name__)

# Divisor used by DepthDataset.quantile_normalize to rescale shifted depth values.
# NOTE(review): presumably a depth quantile measured on high-resolution data — confirm its origin.
high_res_quantile = 7.4776

class Hdf5File(Dataset):
def __init__(self, filename, io, data_label=None, target_label=None, labels=None, preload=True,
Expand Down Expand Up @@ -141,3 +142,87 @@ def __getitem__(self, index):
file_index_to_unload = np.random.choice(loaded_file_indexes)
self.files[file_index_to_unload].unload()
return item

class DepthDataset(LotsOfHdf5Files):
Copy link
Collaborator

@rakhimovv rakhimovv Aug 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of inheritance if you redefine everything? Almost all of the logic for working with HDF5 is, I suppose, already well implemented inside the LotsOfHdf5Files class by @artonson. I think DepthDataset should mostly concentrate on data and target preprocessing, without diving into the HDF5 reading logic.


def __init__(self, io, data_dir, data_label, target_label, task, partition=None,
             transform=None, normalisation=('quantile', 'standartize'), max_loaded_files=0):
    """Set up HDF5 loading through the parent class and remember the preprocessing options.

    Args:
        io: forwarded to the parent constructor.
        data_dir: dataset directory; also inspected for a 'high' / 'low' / 'med' token.
        data_label: forwarded to the parent constructor.
        target_label: forwarded to the parent constructor.
        task: one of 'segmentation', 'regression' or 'two-heads'; selects the
            target produced by ``__getitem__``.
        partition: forwarded to the parent constructor.
        transform: forwarded to the parent constructor.
        normalisation: collection of step names ('quantile', 'standartize') to
            apply to the image; ``None`` disables normalisation.
        max_loaded_files: forwarded to the parent constructor.
    """
    super().__init__(data_dir=data_dir, io=io,
                     data_label=data_label, target_label=target_label,
                     labels=None,
                     partition=partition,
                     transform=transform,
                     max_loaded_files=max_loaded_files)
    self.data_dir = data_dir
    self.task = task
    self.quality = self._get_quantity()
    # Callers pass None when the config has no normalisation entry; store an
    # empty collection so later membership tests do not raise TypeError.
    self.normalisation = () if normalisation is None else normalisation

def _get_quantity(self):
    """Return 'high', 'low' or 'med' if it is an underscore-delimited token of ``data_dir``.

    The levels are checked in that order; ``None`` is returned when none of
    them is present.
    """
    tokens = self.data_dir.split('_')
    for level in ('high', 'low', 'med'):
        if level in tokens:
            return level
    return None

def quantile_normalize(self, data):
    """Return a rescaled copy of ``data``; the input itself is left untouched.

    Steps: background pixels (value 0) are set to ``max + 1.0``, the minimum
    over the object pixels (value != 0) is subtracted from every pixel, and
    the result is divided by the module-level ``high_res_quantile``.
    """
    # mask -> min shift -> quantile
    shifted = np.copy(data)

    # both masks are taken before the background is overwritten
    is_object = shifted != 0
    is_background = shifted == 0

    background_value = shifted.max() + 1.0  # new line
    shifted[is_background] = background_value
    shifted -= shifted[is_object].min()

    shifted /= high_res_quantile

    return shifted

def standartize(self, data):
    """Return a centred and rescaled copy of ``data``.

    The global mean is subtracted, then the result is divided by the largest
    L2 norm taken along axis 1 (skipped unless that norm is positive).
    NOTE(review): this scale is a maximum row norm, not a standard deviation,
    so the output is zero-mean but not unit-variance — confirm the intent.
    """
    centred = np.copy(data)
    centred -= np.mean(centred)

    scale = np.linalg.norm(centred, axis=1).max()
    if not scale > 0:
        # nothing to rescale (e.g. constant input)
        return centred

    centred /= scale
    return centred

def __getitem__(self, index):
    """Load one sample from the parent dataset and build the task-specific target.

    The parent item provides ``'image'`` and ``'distances'``. The image is
    optionally normalised; distances are masked to object pixels, with
    background pixels set to 1.0.

    Returns:
        dict with ``'image'`` plus one target entry chosen by ``self.task``:
        ``'close_to_sharp_mask'`` ('segmentation'), ``'distance_to_sharp'``
        ('regression') or ``'distance_and_close_to_sharp'`` ('two-heads').
        For any other task value only ``'image'`` is present.
    """
    item = super().__getitem__(index)
    data, target = item['image'], item['distances']
    mask_1 = (np.copy(data) != 0.0).astype(float)  # mask for object
    mask_2 = np.where(data == 0)  # mask for background

    # None means "no normalisation"; a membership test on None would raise TypeError.
    steps = self.normalisation if self.normalisation is not None else ()
    if 'quantile' in steps:
        data = self.quantile_normalize(data)
    if 'standartize' in steps:
        data = self.standartize(data)

    dist_new = np.copy(target)
    dist_mask = dist_new * mask_1  # select object points
    dist_mask[mask_2] = 1.0  # background points has max distance to sharp features
    # `x != np.nan` is True for every x, so NaNs have to be detected with np.isnan.
    close_to_sharp = (~np.isnan(dist_mask) & (dist_mask < 1.)).astype(float)

    output = {}

    if self.task == 'two-heads':
        # regression + segmentation (or two-head network) has two targets:
        # distance field and segmented close-to-sharp region of the object
        target = torch.cat([torch.FloatTensor(dist_mask).unsqueeze(0),
                            torch.FloatTensor(close_to_sharp).unsqueeze(0)], dim=0)
        output['distance_and_close_to_sharp'] = target
    elif self.task == 'segmentation':
        target = torch.FloatTensor(close_to_sharp).unsqueeze(0)
        output['close_to_sharp_mask'] = target
    elif self.task == 'regression':
        target = torch.FloatTensor(dist_mask).unsqueeze(0)
        output['distance_to_sharp'] = target

    data = torch.FloatTensor(data).unsqueeze(0)
    output['image'] = data

    return output
Loading