segmentation and depth ds #8
Merged
Changes from all commits · 14 commits
eef7ba4  segmentation and depth ds
bec69a8  corrected depth ds inheritance, metrics in segmentation model
bba6ff1  corrected depth ds inheritance
5cb47c9  added balanced accuracy and ~after run~ fixes
92b454b  added balanced accuracy for real
a75cb74  small fixes
bef0a33  commit before merge
c87b3c5  changed balanced accuracy computation, changed dict keys in depth ds,…
21f4417  Merge branch 'pl_hydra' into segmentation (rakhimovv)
07c1a1e  delete garbage (rakhimovv)
56e1d2e  add more distributed tools (rakhimovv)
4b45ef4  fix metric calculation (rakhimovv)
56d07cb  fix typos (rakhimovv)
aba17cc  fix dimensions (rakhimovv)
New file (7 lines): Hydra config for the meta_arch group, registering the new DepthSegmentator

# @package _group_

pl_class:
  _target_: sharpf.modeling.DepthSegmentator
monitor: val_balanced_accuracy
loss:
  _target_: torch.nn.functional.binary_cross_entropy_with_logits
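For context, a minimal sketch of how this loss entry is consumed, matching the hydra.utils.instantiate(self.hparams.meta_arch.loss, preds, target) call in the training code further down. The cfg literal below is only a stand-in for the loaded group, assuming Hydra 1.0-style instantiate, which resolves _target_ and passes extra positional arguments through to the callable:

import hydra.utils
import torch
from omegaconf import OmegaConf

# Stand-in for the loaded `meta_arch` group above.
cfg = OmegaConf.create({
    "loss": {"_target_": "torch.nn.functional.binary_cross_entropy_with_logits"}
})

preds = torch.randn(2, 1, 64, 64)                      # raw logits
target = torch.randint(0, 2, (2, 1, 64, 64)).float()   # binary mask as float

# Equivalent to F.binary_cross_entropy_with_logits(preds, target).
loss = hydra.utils.instantiate(cfg.loss, preds, target)
print(loss)  # scalar tensor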
New file (19 lines): Hydra config for the PixelSegmentator model

# @package _group_
model_name: PixelSegmentator
params:
  feature_extractor:
    _target_: sharpf.modeling.Unet
    params:
      encoder_name: resnet50
      decoder_use_batchnorm: true
      decoder_channels: [256, 128, 64, 32, 16]
      decoder_attention_type: null
      in_channels: 1
  segmentation_head:
    - _target_: torch.nn.Conv2d
      params:
        in_channels: 16
        out_channels: 1
        kernel_size: 3
        padding: 1
    - _target_: torch.nn.Sigmoid
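PixelSegmentator itself is not part of this diff; judging from the config shape, it presumably wires the feature extractor to a per-pixel head roughly like the hypothetical sketch below (class and argument names here are illustrative, not the repository's actual API):

import torch
import torch.nn as nn

class PixelSegmentatorSketch(nn.Module):
    """Hypothetical composition implied by the config above: a U-Net style
    feature extractor followed by a per-pixel segmentation head."""

    def __init__(self, feature_extractor: nn.Module, segmentation_head: nn.Module):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.segmentation_head = segmentation_head

    def forward(self, x):
        features = self.feature_extractor(x)      # (B, 16, H, W) with the config above
        return self.segmentation_head(features)   # (B, 1, H, W) per-pixel scores

# The head from the config: Conv2d(16 -> 1, 3x3, padding=1), then Sigmoid.
head = nn.Sequential(
    nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1),
    nn.Sigmoid(),
)

# Quick shape check with a dummy extractor standing in for the U-Net.
dummy_extractor = nn.Conv2d(1, 16, kernel_size=3, padding=1)
model = PixelSegmentatorSketch(dummy_extractor, head)
print(model(torch.rand(2, 1, 64, 64)).shape)  # torch.Size([2, 1, 64, 64])

One thing worth noting: the head here already ends in Sigmoid, while the configured loss is binary_cross_entropy_with_logits and DepthSegmentator.forward below applies sigmoid() again when producing masks, so care is needed that the activation is not applied twice.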
New file (15 lines): Hydra config for the dataset transforms

# @package _group_

train:
  - _target_: sharpf.utils.abc_utils.torch.TypeCast

val:
  - _target_: sharpf.utils.abc_utils.torch.TypeCast

test:
  - _target_: sharpf.utils.abc_utils.torch.TypeCast

normalisation:
  - data: ['standartize', 'quantile']
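The training module below composes these per-partition lists with CompositeTransform via hydra.utils.instantiate. Neither TypeCast nor CompositeTransform is shown in this diff, so the following is only a hypothetical sketch of the pattern the config implies:

import torch

class CompositeTransformSketch:
    """Hypothetical stand-in for sharpf.utils.abc_utils.torch.CompositeTransform:
    applies a list of transforms in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, item):
        for transform in self.transforms:
            item = transform(item)
        return item

class TypeCastSketch:
    """Hypothetical stand-in for TypeCast: casts array-like fields to float tensors."""

    def __call__(self, item):
        return {k: torch.as_tensor(v).float() for k, v in item.items()}

transform = CompositeTransformSketch([TypeCastSketch()])
sample = transform({'image': [[0.0, 1.0]], 'close_to_sharp_mask': [[0, 1]]})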
Modified file (+1 line): modeling package __init__, exporting the new module

  from .point_sharpness_regressor import PointSharpnessRegressor
  from .depth_regressor import DepthRegressor
+ from .depth_segmentator import DepthSegmentator
New file (137 lines): the DepthSegmentator Lightning module

import logging

import hydra
import torch
import torch.nn as nn
from pytorch_lightning import TrainResult
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.metrics.functional import stat_scores
from torch.utils.data import DataLoader

from sharpf.utils.comm import get_batch_size, all_gather, synchronize
from ..metrics import balanced_accuracy
from ..model.build import build_model
from ...data import DepthMapIO
from ...utils.abc_utils.hdf5 import DepthDataset
from ...utils.abc_utils.torch import CompositeTransform

log = logging.getLogger(__name__)


class DepthSegmentator(LightningModule):

    def __init__(self, cfg):
        super().__init__()
        self.hparams = cfg
        self.task = 'segmentation'
        self.model = build_model(self.hparams.model)
        self.example_input_array = torch.rand(1, 1, 64, 64)
        self.data_dir = hydra.utils.to_absolute_path(self.hparams.data.data_dir)

        # convert BatchNorm to SyncBatchNorm whenever training is distributed
        dist_backend = self.hparams.trainer.distributed_backend
        if (dist_backend is not None and 'ddp' in dist_backend) or (
                dist_backend is None and self.hparams.trainer.gpus is not None and (
                self.hparams.trainer.gpus > 1 or self.hparams.trainer.num_nodes > 1)):
            log.info('Converting BatchNorm to SyncBatchNorm. Do not forget other batch-dimension dependent operations.')
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)

    def forward(self, x, as_mask=True):
        out = self.model(x)
        if as_mask:
            return (out.sigmoid() > 0.5).long()
        return out  # reuse the computed output instead of running a second forward pass

    def training_step(self, batch, batch_idx):
        points, target = batch['image'], batch['close_to_sharp_mask']
        points = points.unsqueeze(1) if points.dim() == 3 else points
        preds = self.forward(points, as_mask=False)
        loss = hydra.utils.instantiate(self.hparams.meta_arch.loss, preds, target)
        result = TrainResult(minimize=loss)
        result.log('train_loss', loss, prog_bar=True)
        return result

    def _shared_eval_step(self, batch, batch_idx, prefix):
        points, target = batch['image'], batch['close_to_sharp_mask']
        points = points.unsqueeze(1) if points.dim() == 3 else points
        preds = self.forward(points, as_mask=True)
        # per-sample confusion statistics for the positive class
        stats = [list(stat_scores(preds[i], target[i], class_index=1)) for i in range(preds.size(0))]
        tp, fp, tn, fn, sup = torch.Tensor(stats).to(preds.device).T.unsqueeze(2)  # each of size (batch, 1)
        return {'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn, 'sup': sup}

    def _shared_eval_epoch_end(self, outputs, prefix):
        # gather across sub-batches
        tp = torch.cat([output['tp'] for output in outputs], dim=0)
        fp = torch.cat([output['fp'] for output in outputs], dim=0)
        tn = torch.cat([output['tn'] for output in outputs], dim=0)
        fn = torch.cat([output['fn'] for output in outputs], dim=0)

        # gather results across gpus
        synchronize()
        tp = torch.cat(all_gather(tp), dim=0)
        fp = torch.cat(all_gather(fp), dim=0)
        tn = torch.cat(all_gather(tn), dim=0)
        fn = torch.cat(all_gather(fn), dim=0)

        # calculate metrics
        ba = balanced_accuracy(tp, fp, tn, fn)

        logs = {f'{prefix}_balanced_accuracy': ba}
        return {f'{prefix}_balanced_accuracy': ba, 'log': logs}

    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx, prefix='val')

    def test_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx, prefix='test')

    def validation_epoch_end(self, outputs):
        return self._shared_eval_epoch_end(outputs, prefix='val')

    def test_epoch_end(self, outputs):
        return self._shared_eval_epoch_end(outputs, prefix='test')

    def configure_optimizers(self):
        optimizer = hydra.utils.instantiate(self.hparams.opt, params=self.parameters())
        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer)
        return [optimizer], [scheduler]

    def _get_dataset(self, partition):
        if hasattr(self, f'{partition}_set') and getattr(self, f'{partition}_set') is not None:
            return getattr(self, f'{partition}_set')
        transform = CompositeTransform([hydra.utils.instantiate(tf) for tf in self.hparams.transforms[partition]])
        if 'normalisation' in self.hparams.transforms.keys():
            normalisation = self.hparams.transforms['normalisation']
        else:
            normalisation = None

        return DepthDataset(
            data_dir=self.data_dir,
            io=DepthMapIO,
            data_label=self.hparams.data.data_label,
            target_label=self.hparams.data.target_label,
            task=self.task,
            partition=partition,
            transform=transform,
            max_loaded_files=self.hparams.data.max_loaded_files,
            normalisation=normalisation
        )

    def _get_dataloader(self, partition):
        dataset = self._get_dataset(partition)
        num_workers = self.hparams.data_loader[partition].num_workers
        batch_size = get_batch_size(self.hparams.data_loader[partition].total_batch_size)
        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    def setup(self, stage: str):
        self.train_set = self._get_dataset('train') if stage == 'fit' else None
        self.val_set = self._get_dataset('val') if stage == 'fit' else None
        self.test_set = self._get_dataset('test') if stage == 'test' else None

    def train_dataloader(self):
        return self._get_dataloader('train')

    def val_dataloader(self):
        return self._get_dataloader('val')

    def test_dataloader(self):
        return self._get_dataloader('val')  # FIXME
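Two details of the module above are easy to misread, so here is a small standalone illustration with toy numbers (not real model output): the as_mask branch of forward, and the transpose/unsqueeze trick in _shared_eval_step that turns per-sample stat_scores rows into five (B, 1) tensors:

import torch

# forward(as_mask=True): raw logits -> binary mask
logits = torch.randn(2, 1, 64, 64)        # what self.model(x) would return
mask = (logits.sigmoid() > 0.5).long()    # values in {0, 1}

# _shared_eval_step: B rows of [tp, fp, tn, fn, sup] -> five (B, 1) tensors.
# (2, 5) -> .T -> (5, 2) -> .unsqueeze(2) -> (5, 2, 1), unpacked along dim 0.
stats = [[3., 1., 10., 2., 5.],           # sample 1
         [4., 0., 11., 1., 5.]]           # sample 2
tp, fp, tn, fn, sup = torch.Tensor(stats).T.unsqueeze(2)
print(tp.shape)  # torch.Size([2, 1])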
New file (21 lines): the balanced accuracy metric

import torch


def balanced_accuracy(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
    """
    Calculate the balanced accuracy for one class from the provided statistics.

    Args:
        tp (Tensor): of shape (B, 1). True positive values.
        fp (Tensor): of shape (B, 1). False positive values.
        tn (Tensor): of shape (B, 1). True negative values.
        fn (Tensor): of shape (B, 1). False negative values.

    Returns:
        torch.Tensor: balanced accuracy value
    """
    tpr = tp / (tp + fn)  # (B, 1)
    tnr = tn / (tn + fp)  # (B, 1)
    # a sample with no positives (or no negatives) yields a NaN rate;
    # substitute the other rate so the sample still contributes a sane value
    tpr = torch.where(torch.isnan(tpr), tnr, tpr)  # (B, 1)
    tnr = torch.where(torch.isnan(tnr), tpr, tnr)  # (B, 1)
    return 0.5 * torch.mean(tpr + tnr)
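A quick sanity check of balanced_accuracy on a toy batch; the second sample has no positive pixels at all, which exercises the NaN substitution above (expected values worked out by hand):

import torch

# sample 1: TPR = 3/(3+1) = 0.75, TNR = 4/(4+4) = 0.5
# sample 2: no positives, so TPR is NaN and is replaced by TNR = 5/(5+0) = 1.0
tp = torch.tensor([[3.], [0.]])
fn = torch.tensor([[1.], [0.]])
tn = torch.tensor([[4.], [5.]])
fp = torch.tensor([[4.], [0.]])
print(balanced_accuracy(tp, fp, tn, fn))
# 0.5 * mean([0.75 + 0.5, 1.0 + 1.0]) = 0.5 * mean([1.25, 2.0]) = 0.8125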
Modified file: hdf5 utils package __init__, exporting DepthDataset

- from .dataset import Hdf5File, LotsOfHdf5Files
+ from .dataset import Hdf5File, LotsOfHdf5Files, DepthDataset
Review comment (on the DepthDataset changes):
What is the point of inheritance if you redefine everything? Almost all of the logic for working with HDF5, I suppose, is already well implemented inside the LotsOfHdf5Files class by @artonson. I suppose DepthDataset should mostly concentrate on data and target preprocessing, without diving into the HDF5 reading logic.
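To make that suggestion concrete, a hypothetical sketch of the shape the reviewer seems to have in mind; the __getitem__ signature on LotsOfHdf5Files is an assumption, and only the delegation pattern matters here:

from sharpf.utils.abc_utils.hdf5.dataset import LotsOfHdf5Files

class DepthDataset(LotsOfHdf5Files):
    """Sketch: inherit the HDF5 reading logic unchanged and only add
    depth-specific preprocessing of data and target."""

    def __init__(self, *args, normalisation=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.normalisation = normalisation

    def __getitem__(self, index):
        item = super().__getitem__(index)   # reuse inherited HDF5 access as-is
        return self._preprocess(item)       # only this part is task-specific

    def _preprocess(self, item):
        # depth/target normalisation would go here
        return item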