segmentation and depth ds #8
@@ -0,0 +1,6 @@
# @package _group_
#
# target: sharpf.modeling.DepthSegmentator
# monitor: val_balanced_accuracy  # TODO: add balanced accuracy metric class
# loss:
#   target: torch.nn.BCEWithLogitsLoss
@@ -0,0 +1,19 @@
# @package _group_
model_name: PixelSegmentator
params:
  feature_extractor:
    target: sharpf.modeling.Unet
    params:
      encoder_name: resnet50
      decoder_use_batchnorm: true
      decoder_channels: [256, 128, 64, 32, 16]
      decoder_attention_type: null
      in_channels: 1
  segmentation_head:
    - target: torch.nn.Conv2d
      params:
        in_channels: 16
        out_channels: 1
        kernel_size: 3
        padding: 1
    - target: torch.nn.Sigmoid
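For orientation: the config above wires a U-Net feature extractor, whose last decoder stage outputs 16 channels, into a per-pixel Conv2d + Sigmoid head. A minimal sketch of how such a config could map onto PyTorch modules is given below; the class and wiring are illustrative assumptions, not the repo's build_model or sharpf.modeling.Unet.

import torch
import torch.nn as nn


class PixelSegmentatorSketch(nn.Module):
    # Illustrative sketch (assumption): a feature extractor followed by a per-pixel head.

    def __init__(self, feature_extractor: nn.Module, segmentation_head: nn.Module):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.segmentation_head = segmentation_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.feature_extractor(x)      # (B, 16, H, W) with the decoder_channels above
        return self.segmentation_head(features)   # (B, 1, H, W) probabilities after Sigmoid


# The segmentation_head list in the YAML maps naturally onto nn.Sequential:
head = nn.Sequential(
    nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1),
    nn.Sigmoid(),
)

Note that a Sigmoid output pairs with torch.nn.BCELoss; torch.nn.BCEWithLogitsLoss, mentioned in the commented-out config at the top of this diff, expects raw logits instead.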
@@ -0,0 +1,15 @@
# @package _group_

train:
  - target: sharpf.utils.abc_utils.torch.TypeCast

val:
  - target: sharpf.utils.abc_utils.torch.TypeCast

test:
  - target: sharpf.utils.abc_utils.torch.TypeCast

normalisation:
  - data: ['standartize', 'quantile']
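Each per-partition list above is later instantiated and composed in DepthSegmentator._get_dataset (shown further down in this diff). The stand-in below sketches what that composition amounts to; it is written for illustration and is not the actual sharpf.utils.abc_utils.torch.CompositeTransform.

class CompositeTransformSketch:
    # Illustrative stand-in (assumption): applies a list of transforms in order.

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data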
@@ -0,0 +1,139 @@
import logging

import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import TrainResult
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.metrics import tensor_metric
from torch.utils.data import DataLoader

from sharpf.utils.comm import get_batch_size
from sharpf.utils.losses import balanced_accuracy
from ..model.build import build_model
from ...data import DepthMapIO
from ...utils.abc_utils import LotsOfHdf5Files, DepthDataset
from ...utils.abc_utils.torch import CompositeTransform
from ...utils.config import flatten_omegaconf

log = logging.getLogger(__name__)


@tensor_metric()
def gather_sum(x: torch.Tensor) -> torch.Tensor:
    return x


class DepthSegmentator(LightningModule):

    def __init__(self, cfg):
        super().__init__()
        self.hparams = flatten_omegaconf(cfg)  # there should be a better official way later
        self.cfg = cfg
        self.task = 'segmentation'
        self.model = build_model(cfg.model)
        self.example_input_array = torch.rand(1, 1, 64, 64)
        self.data_dir = hydra.utils.to_absolute_path(self.cfg.data.data_dir)

        dist_backend = self.cfg.trainer.distributed_backend
        if (dist_backend is not None and 'ddp' in dist_backend) or (
                dist_backend is None and self.cfg.trainer.gpus is not None and (
                self.cfg.trainer.gpus > 1 or self.cfg.trainer.num_nodes > 1)):
            log.info('Converting BatchNorm to SyncBatchNorm. Do not forget other batch-dimension dependent operations.')
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        points, distances = batch['image'], batch['distances']
        points = points.unsqueeze(1) if points.dim() == 3 else points
        preds = self.forward(points)
        loss = hydra.utils.instantiate(self.cfg.meta_arch.loss, preds, distances)
        result = TrainResult(minimize=loss)
        result.log('train_loss', loss, prog_bar=True)
        return result

    def _shared_eval_step(self, batch, batch_idx, prefix):
        metric_name = 'balanced_accuracy'
        metric = balanced_accuracy
        points, distances = batch['image'], batch['distances']
        points = points.unsqueeze(1) if points.dim() == 3 else points
        preds = self.forward(points)

        metric_value = metric(preds, distances)  # (batch)
        # loss = hydra.utils.instantiate(self.cfg.meta_arch.loss, preds, distances)
        # self.logger[0].experiment.add_scalars('losses', {f'{prefix}_loss': loss})
        # TODO Consider pl.EvalResult, once there are good examples how to use it
        return {f'{metric_name}_sum': metric_value.sum(),
                'batch_size': torch.tensor(points.size(0), device=self.device)}

    def _shared_eval_epoch_end(self, outputs, prefix):
        metric_name = 'balanced_accuracy'
        metric_sum = 0
        size = 0
        for output in outputs:
            metric_sum += output[f'{metric_name}_sum']
            size += output['batch_size']
        mean_metric = gather_sum(metric_sum) / gather_sum(size)

[Review comment] I think the balanced accuracy needs a bit different gathering logic across batches (see the sketch after this file).

        logs = {f'{prefix}_mean_{metric_name}': mean_metric}
        return {f'{prefix}_mean_{metric_name}': mean_metric, 'log': logs}

    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx, prefix='val')

    def test_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx, prefix='test')

    def validation_epoch_end(self, outputs):
        return self._shared_eval_epoch_end(outputs, prefix='val')

    def test_epoch_end(self, outputs):
        return self._shared_eval_epoch_end(outputs, prefix='test')

    def configure_optimizers(self):
        optimizer = hydra.utils.instantiate(self.cfg.opt, params=self.parameters())
        scheduler = hydra.utils.instantiate(self.cfg.scheduler, optimizer=optimizer)
        return [optimizer], [scheduler]

    def _get_dataset(self, partition):
        if hasattr(self, f'{partition}_set') and getattr(self, f'{partition}_set') is not None:
            return getattr(self, f'{partition}_set')
        transform = CompositeTransform([hydra.utils.instantiate(tf) for tf in self.cfg.transforms[partition]])
        if 'normalisation' in self.cfg.transforms.keys():
            normalisation = self.cfg.transforms['normalisation']
        else:
            normalisation = None

        return DepthDataset(
            data_dir=self.data_dir,
            io=DepthMapIO,
            data_label=self.cfg.data.data_label,
            target_label=self.cfg.data.target_label,
            task=self.task,
            partition=partition,
            transform=transform,
            max_loaded_files=self.cfg.data.max_loaded_files,
            normalisation=normalisation
        )

    def _get_dataloader(self, partition):
        dataset = self._get_dataset(partition)
        num_workers = self.cfg.data_loader[partition].num_workers
        batch_size = get_batch_size(self.cfg.data_loader[partition].total_batch_size)
        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    def setup(self, stage: str):
        self.train_set = self._get_dataset('train') if stage == 'fit' else None
        self.val_set = self._get_dataset('val') if stage == 'fit' else None
        self.test_set = self._get_dataset('test') if stage == 'test' else None

    def train_dataloader(self):
        return self._get_dataloader('train')

    def val_dataloader(self):
        return self._get_dataloader('val')

    def test_dataloader(self):
        return self._get_dataloader('val')  # FIXME
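Regarding the inline comment on _shared_eval_epoch_end above: averaging per-sample balanced accuracies is not, in general, the same as the balanced accuracy over the whole validation set. One possible gathering scheme the comment may be pointing at, sketched under that assumption (this is not code from the PR), is to accumulate per-class counts across batches (and across processes, e.g. via the same gather_sum helper) and form the ratio only at epoch end.

import torch


def accumulate_confusion(preds_bin: torch.Tensor, targets: torch.Tensor) -> dict:
    # Per-batch counts that can simply be summed across batches/processes.
    targets = targets.float().view(-1)
    preds_bin = preds_bin.float().view(-1)
    return {
        'tp': (preds_bin * targets).sum(),
        'tn': ((1 - preds_bin) * (1 - targets)).sum(),
        'pos': targets.sum(),
        'neg': (1 - targets).sum(),
    }


def balanced_accuracy_from_counts(tp, tn, pos, neg, eps=1e-8):
    # Dataset-level balanced accuracy: mean of the two per-class recalls.
    return 0.5 * (tp / (pos + eps) + tn / (neg + eps))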
@@ -133,3 +133,79 @@ def __getitem__(self, index):
            file_index_to_unload = np.random.choice(loaded_file_indexes)
            self.files[file_index_to_unload].unload()
        return item


class DepthDataset(LotsOfHdf5Files):

[Review comment] What is the point of inheritance if you redefine everything? Almost all of the logic for working with hdf5, I suppose, is already well implemented inside the parent class.

    def __init__(self, io, data_dir, data_label, target_label, task, partition=None,
                 transform=None, normalisation=['quantile', 'standartize'], max_loaded_files=0):
        super().__init__(data_dir=data_dir, io=io,
                         data_label=data_label, target_label=target_label,
                         labels=None,
                         partition=partition,
                         transform=transform,
                         max_loaded_files=max_loaded_files)
        self.task = task
        self.quality = self._get_quantity()
        self.normalisation = normalisation

    def _get_quantity(self):
        # infer 'high' / 'med' / 'low' from the data directory name
        data_dir_split = self.data_dir.split('_')
        if 'high' in data_dir_split:
            return 'high'
        elif 'low' in data_dir_split:
            return 'low'
        elif 'med' in data_dir_split:
            return 'med'

    def quantile_normalize(self, data):
        # mask -> min shift -> quantile
        norm_data = np.copy(data)
        mask_obj = np.where(norm_data != 0)
        mask_back = np.where(norm_data == 0)
        norm_data[mask_back] = norm_data.max() + 1.0  # new line
        norm_data -= norm_data[mask_obj].min()

        # high_res_quantile is assumed to be defined elsewhere; it is not part of this diff
        norm_data /= high_res_quantile

        return norm_data

    def standartize(self, data):
        # zero mean, unit variance
        standart_data = np.copy(data)
        standart_data -= np.mean(standart_data)
        std = np.linalg.norm(standart_data, axis=1).max()
        if std > 0:
            standart_data /= std

        return standart_data

    def __getitem__(self, index):
        data, target = super().__getitem__(index)
        mask_1 = (np.copy(data) != 0.0).astype(float)  # mask for the object
        mask_2 = np.where(data == 0)  # mask for the background

        if 'quantile' in self.normalisation:
            data = self.quantile_normalize(data)
        if 'standartize' in self.normalisation:
            data = self.standartize(data)

        dist_new = np.copy(target)
        dist_mask = dist_new * mask_1  # select object points
        dist_mask[mask_2] = 1.0  # background points have the maximum distance to sharp features
        close_to_sharp = np.array(~np.isnan(dist_mask) & (dist_mask < 1.)).astype(float)

        if self.task == 'two-heads':
            # regression + segmentation (a two-head network) has two targets:
            # the distance field and the segmented close-to-sharp region of the object
            target = torch.cat([torch.FloatTensor(dist_mask).unsqueeze(0),
                                torch.FloatTensor(close_to_sharp).unsqueeze(0)], dim=0)
        if self.task == 'segmentation':
            target = torch.FloatTensor(close_to_sharp)
        elif self.task == 'regression':
            target = torch.FloatTensor(dist_mask)

        data = torch.FloatTensor(data).unsqueeze(0)

        return {'data': data, 'target': target}
[Review comment] I meant something more informative :) like …

[Review comment] so then keys should differ for each task? what's the purpose?

[Review comment] or should they be similar for each task in the depth dataset, just with names that describe what kind of data this is?

[Review comment] for the sake of readability

[Review comment] and also, for example, if you want to add several new additional targets later, it would be much easier to add them just as new keys: …

[Review comment] also, for the segmentation task, for example, I have a binary close-to-sharp target and for regression a distance field, so there should either be several returns for each task or the same names
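One possible reading of the thread above, sketched as a hypothetical item layout: name each tensor by what it contains and keep the keys consistent across tasks. The key names below are illustrative assumptions, not the PR's code; note that DepthSegmentator as written reads batch['image'] and batch['distances'], while this dataset returns 'data' and 'target'.

import torch


def make_item(image: torch.Tensor, distances: torch.Tensor, close_to_sharp: torch.Tensor) -> dict:
    # Hypothetical example of descriptive, task-agnostic keys (assumed names).
    return {
        'image': image,                         # normalised depth map, shape (1, H, W)
        'distances': distances,                 # distance-to-sharp-feature field, shape (H, W)
        'close_to_sharp_mask': close_to_sharp,  # binary close-to-sharp target, shape (H, W)
    }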
[Review comment] I haven't found the implementation of balanced_accuracy.
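The training module above imports balanced_accuracy from sharpf.utils.losses, but its implementation is not part of this diff. A minimal sketch of what a per-sample balanced accuracy for binary segmentation could look like is given below; it is an assumption about the intended behaviour, not the repository's actual function.

import torch


def balanced_accuracy(preds: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # Sketch (assumption): mean of sensitivity and specificity per sample.
    # preds are probabilities in [0, 1]; preds/targets are (B, 1, H, W) or (B, H, W).
    preds_bin = (preds > threshold).float().view(preds.size(0), -1)
    targets = targets.float().view(targets.size(0), -1)

    eps = 1e-8
    tp = (preds_bin * targets).sum(dim=1)
    tn = ((1 - preds_bin) * (1 - targets)).sum(dim=1)
    pos = targets.sum(dim=1)
    neg = (1 - targets).sum(dim=1)

    recall_pos = tp / (pos + eps)  # sensitivity
    recall_neg = tn / (neg + eps)  # specificity
    return 0.5 * (recall_pos + recall_neg)  # shape (B,)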