-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
segmentation and depth ds #8
Changes from 4 commits
eef7ba4
bec69a8
bba6ff1
5cb47c9
92b454b
a75cb74
bef0a33
c87b3c5
21f4417
07c1a1e
56e1d2e
4b45ef4
56d07cb
aba17cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# @package _group_

# Lightning module trained with this configuration.
target: sharpf.modeling.DepthSegmentator
# Must equal the key emitted by DepthSegmentator.validation_epoch_end, which
# is built as '<prefix>_mean_<metric_name>' -> 'val_mean_balanced_accuracy'.
monitor: val_mean_balanced_accuracy  # TODO: add balanced accuracy metric class
loss:
  # NOTE(review): this loss applies a sigmoid internally, while the
  # PixelSegmentator head config ends with torch.nn.Sigmoid -- confirm the
  # predictions are not squashed twice.
  target: torch.nn.functional.binary_cross_entropy_with_logits
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# @package _group_
# Segmentation network: a U-Net feature extractor followed by a per-pixel head.
# NOTE(review): nesting below is reconstructed from key order -- confirm
# against the original file.
model_name: PixelSegmentator
params:
  feature_extractor:
    target: sharpf.modeling.Unet
    params:
      encoder_name: resnet50
      decoder_use_batchnorm: true
      decoder_channels: [256, 128, 64, 32, 16]
      decoder_attention_type: null
      in_channels: 1
  # The first head layer takes 16 channels, the last entry of decoder_channels.
  segmentation_head:
    - target: torch.nn.Conv2d
      params:
        in_channels: 16
        out_channels: 1
        kernel_size: 3
        padding: 1
    - target: torch.nn.Sigmoid
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# @package _group_

# Per-partition transform pipelines; each entry is instantiated in order.
train:
  - target: sharpf.utils.abc_utils.torch.TypeCast

val:
  - target: sharpf.utils.abc_utils.torch.TypeCast

test:
  - target: sharpf.utils.abc_utils.torch.TypeCast

# Flat list of normalisation step names. DepthDataset selects steps with
# `'quantile' in normalisation` / `'standartize' in normalisation`, so the
# names must be top-level list entries rather than nested under another key.
normalisation: ['standartize', 'quantile']
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import logging | ||
|
||
import hydra | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from pytorch_lightning import TrainResult | ||
from pytorch_lightning.core.lightning import LightningModule | ||
from pytorch_lightning.metrics import tensor_metric | ||
from torch.utils.data import DataLoader | ||
|
||
from sharpf.utils.comm import get_batch_size | ||
from sharpf.utils.losses import balanced_accuracy | ||
from ..model.build import build_model | ||
from ...data import DepthMapIO | ||
from ...utils.abc_utils.hdf5 import DepthDataset | ||
from ...utils.abc_utils.torch import CompositeTransform | ||
from ...utils.config import flatten_omegaconf | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
@tensor_metric()
def gather_sum(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` unchanged; any cross-process handling comes from the decorator.

    NOTE(review): presumably ``tensor_metric`` reduces the value across
    distributed workers (a sum, going by this function's name) -- confirm
    against the installed pytorch_lightning version.
    """
    return x
|
||
|
||
class DepthSegmentator(LightningModule):
    """Lightning module for binary segmentation of depth images.

    The network is built from ``cfg.model``, trained with the loss configured
    in ``cfg.meta_arch.loss`` and evaluated with ``balanced_accuracy``.
    Datasets are ``DepthDataset`` instances rooted at ``cfg.data.data_dir``.
    """

    def __init__(self, cfg):
        """Build the model from ``cfg``; use SyncBatchNorm for multi-device runs."""
        super().__init__()
        self.hparams = flatten_omegaconf(cfg)  # there should be better official way later
        self.cfg = cfg
        self.task = 'segmentation'
        self.model = build_model(cfg.model)
        self.example_input_array = torch.rand(1, 1, 64, 64)
        self.data_dir = hydra.utils.to_absolute_path(self.cfg.data.data_dir)

        dist_backend = self.cfg.trainer.distributed_backend
        if (dist_backend is not None and 'ddp' in dist_backend) or (
                dist_backend is None and self.cfg.trainer.gpus is not None and (
                self.cfg.trainer.gpus > 1 or self.cfg.trainer.num_nodes > 1)):
            log.info('Converting BatchNorm to SyncBatchNorm. Do not forget other batch-dimension dependent operations.')
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    @staticmethod
    def _unpack_batch(batch):
        """Split a batch into (input, target), adding a channel axis to 3-D input."""
        points, distances = batch['image'], batch['distances']
        if points.dim() == 3:
            points = points.unsqueeze(1)
        return points, distances

    def training_step(self, batch, batch_idx):
        """Compute the configured loss for one batch and log it as ``train_loss``."""
        points, distances = self._unpack_batch(batch)
        preds = self.forward(points)

        loss = hydra.utils.instantiate(self.cfg.meta_arch.loss, preds, distances)
        result = TrainResult(minimize=loss)
        result.log('train_loss', loss, prog_bar=True)
        return result

    def _shared_eval_step(self, batch, batch_idx, prefix):
        """Return the summed metric and the batch size for one evaluation batch."""
        metric_name = 'balanced_accuracy'
        points, distances = self._unpack_batch(batch)
        preds = self.forward(points)

        metric_value = balanced_accuracy(preds, distances)  # (batch)
        # TODO Consider pl.EvalResult, once there are good examples how to use it
        return {f'{metric_name}_sum': metric_value.sum(),
                'batch_size': torch.tensor(points.size(0), device=self.device)}

    def _shared_eval_epoch_end(self, outputs, prefix):
        """Aggregate per-batch sums into ``{prefix}_mean_balanced_accuracy``."""
        metric_name = 'balanced_accuracy'
        metric_sum = 0
        size = 0
        for output in outputs:
            metric_sum += output[f'{metric_name}_sum']
            size += output['batch_size']
        # NOTE(review): raised in code review -- balanced accuracy may need a
        # different gathering logic across batches than this sample-weighted
        # mean of per-batch values. Left unchanged pending that decision.
        mean_metric = gather_sum(metric_sum) / gather_sum(size)

        key = f'{prefix}_mean_{metric_name}'
        return {key: mean_metric, 'log': {key: mean_metric}}

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch."""
        return self._shared_eval_step(batch, batch_idx, prefix='val')

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch."""
        return self._shared_eval_step(batch, batch_idx, prefix='test')

    def validation_epoch_end(self, outputs):
        """Aggregate validation outputs."""
        return self._shared_eval_epoch_end(outputs, prefix='val')

    def test_epoch_end(self, outputs):
        """Aggregate test outputs."""
        return self._shared_eval_epoch_end(outputs, prefix='test')

    def configure_optimizers(self):
        """Instantiate the optimizer and scheduler from ``cfg.opt`` / ``cfg.scheduler``."""
        optimizer = hydra.utils.instantiate(self.cfg.opt, params=self.parameters())
        scheduler = hydra.utils.instantiate(self.cfg.scheduler, optimizer=optimizer)
        return [optimizer], [scheduler]

    def _get_dataset(self, partition):
        """Return the dataset for ``partition``, reusing one built by ``setup``."""
        cached = getattr(self, f'{partition}_set', None)
        if cached is not None:
            return cached

        transforms_cfg = self.cfg.transforms
        transform = CompositeTransform([hydra.utils.instantiate(tf) for tf in transforms_cfg[partition]])
        # Membership is tested on the config node itself; `.keys` without a
        # call is a bound method and cannot be used with `in`.
        # NOTE(review): DepthDataset selects steps by name membership, so the
        # configured value should be a flat list of names -- confirm.
        if 'normalisation' in transforms_cfg:
            normalisation = transforms_cfg['normalisation']
        else:
            # An empty list means "no normalisation" and, unlike None, is safe
            # for the dataset's `in` checks.
            normalisation = []

        return DepthDataset(
            data_dir=self.data_dir,
            io=DepthMapIO,
            data_label=self.cfg.data.data_label,
            target_label=self.cfg.data.target_label,
            task=self.task,
            partition=partition,
            transform=transform,
            max_loaded_files=self.cfg.data.max_loaded_files,
            normalisation=normalisation
        )

    def _get_dataloader(self, partition):
        """Build a ``DataLoader`` for ``partition`` from ``cfg.data_loader``."""
        dataset = self._get_dataset(partition)
        loader_cfg = self.cfg.data_loader[partition]
        batch_size = get_batch_size(loader_cfg.total_batch_size)
        return DataLoader(dataset, batch_size=batch_size, num_workers=loader_cfg.num_workers, pin_memory=True)

    def setup(self, stage: str):
        """Create the datasets needed for ``stage`` ('fit' or 'test')."""
        self.train_set = self._get_dataset('train') if stage == 'fit' else None
        self.val_set = self._get_dataset('val') if stage == 'fit' else None
        self.test_set = self._get_dataset('test') if stage == 'test' else None

    def train_dataloader(self):
        """Return the training data loader."""
        return self._get_dataloader('train')

    def val_dataloader(self):
        """Return the validation data loader."""
        return self._get_dataloader('val')

    def test_dataloader(self):
        """Return the loader used for testing (currently the validation partition)."""
        return self._get_dataloader('val')  # FIXME
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .dataset import Hdf5File, LotsOfHdf5Files | ||
from .dataset import Hdf5File, LotsOfHdf5Files, DepthDataset |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
|
||
log = logging.getLogger(__name__) | ||
|
||
high_res_quantile = 7.4776 | ||
|
||
class Hdf5File(Dataset): | ||
def __init__(self, filename, io, data_label=None, target_label=None, labels=None, preload=True, | ||
|
@@ -133,3 +134,81 @@ def __getitem__(self, index): | |
file_index_to_unload = np.random.choice(loaded_file_indexes) | ||
self.files[file_index_to_unload].unload() | ||
return item | ||
|
||
class DepthDataset(LotsOfHdf5Files):
    """Depth-image dataset that normalises images and derives task targets.

    Items come from the parent class; this class post-processes the
    ``'image'`` and ``'distances'`` entries. Pixels equal to 0 in the image
    are treated as background.

    TODO(review): a reviewer asked whether the overrides here duplicate logic
    already provided by the parent class -- revisit.
    """

    def __init__(self, io, data_dir, data_label, target_label, task, partition=None,
                 transform=None, normalisation=('quantile', 'standartize'), max_loaded_files=0):
        """Store the task and normalisation settings on top of the parent set-up.

        ``normalisation`` is a collection of step names ('quantile' and/or
        'standartize'); ``None`` or an empty collection disables normalisation.
        ``task`` selects the target layout: 'segmentation', 'regression' or
        'two-heads'.
        """
        super().__init__(data_dir=data_dir, io=io,
                         data_label=data_label, target_label=target_label,
                         labels=None,
                         partition=partition,
                         transform=transform,
                         max_loaded_files=max_loaded_files)
        self.data_dir = data_dir
        self.task = task
        self.quality = self._get_quantity()
        # None would break the `in` checks in __getitem__.
        self.normalisation = normalisation if normalisation is not None else ()

    def _get_quantity(self):
        """Return 'high', 'low' or 'med' if it is a '_'-separated token of data_dir, else None."""
        tokens = self.data_dir.split('_')
        for quality in ('high', 'low', 'med'):
            if quality in tokens:
                return quality
        return None

    def quantile_normalize(self, data):
        """Shift object depths so their minimum is 0 and scale by ``high_res_quantile``.

        Background pixels (value 0) are first set to ``max + 1.0``. Returns a
        new array; ``data`` is not modified.
        """
        norm_data = np.copy(data)
        object_pixels = norm_data != 0
        norm_data[~object_pixels] = norm_data.max() + 1.0
        # Without object pixels there is no minimum to shift by.
        if object_pixels.any():
            norm_data -= norm_data[object_pixels].min()

        norm_data /= high_res_quantile

        return norm_data

    def standartize(self, data):
        """Centre ``data`` on its mean and divide by the largest row L2 norm.

        The divisor is ``np.linalg.norm(..., axis=1).max()``, not the standard
        deviation. Returns a new array; ``data`` is not modified.
        """
        standart_data = np.copy(data)
        standart_data -= np.mean(standart_data)
        scale = np.linalg.norm(standart_data, axis=1).max()
        if scale > 0:
            standart_data /= scale

        return standart_data

    def __getitem__(self, index):
        """Return ``{'image': ..., 'distances': ...}`` as float tensors for ``index``.

        The image gets a leading channel axis. The target depends on
        ``self.task``: the close-to-sharp mask ('segmentation'), the distance
        field ('regression') or both stacked ('two-heads'). For any other task
        the parent's target is returned unchanged.
        """
        item = super().__getitem__(index)
        data, target = item['image'], item['distances']
        # Computed before normalisation changes the pixel values.
        object_mask = np.asarray(data) != 0

        if 'quantile' in self.normalisation:
            data = self.quantile_normalize(data)
        if 'standartize' in self.normalisation:
            data = self.standartize(data)

        # background points has max distance to sharp features
        distances = np.where(object_mask, np.asarray(target, dtype=float), 1.0)
        # np.isnan is required: `x != np.nan` is True for every x.
        close_to_sharp = (~np.isnan(distances) & (distances < 1.)).astype(float)

        if self.task == 'two-heads':
            # regression + segmentation (or two-head network) has two targets:
            # distance field and segmented close-to-sharp region of the object
            target = torch.cat([torch.FloatTensor(distances).unsqueeze(0),
                                torch.FloatTensor(close_to_sharp).unsqueeze(0)], dim=0)
        elif self.task == 'segmentation':
            target = torch.FloatTensor(close_to_sharp).unsqueeze(0)
        elif self.task == 'regression':
            target = torch.FloatTensor(distances).unsqueeze(0)

        data = torch.FloatTensor(data).unsqueeze(0)

        return {'image': data, 'distances': target}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
|
||
from sharpf.utils.callbacks import FitDurationCallback | ||
from sharpf.utils.collect_env import collect_env_info | ||
from sharpf.modeling.meta_arch.depth_regressor import DepthRegressor | ||
from sharpf.modeling.meta_arch.depth_segmentator import DepthSegmentator | ||
|
||
from configs import trainer, optimizer, scheduler | ||
|
||
|
@@ -36,7 +38,7 @@ def main(cfg: DictConfig): | |
log.info(f"Original working directory: {hydra.utils.get_original_cwd()}") | ||
seed_everything(cfg.seed) | ||
|
||
model = instantiate(cfg.meta_arch, cfg=cfg) | ||
model = DepthRegressor(cfg) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
if cfg.weights is not None: | ||
model.load_state_dict(torch.load(cfg.weights)['state_dict']) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't found the implementation of `balanced_accuracy`.