diff --git a/configs/meta_arch/depth-segmentator.yaml b/configs/meta_arch/depth-segmentator.yaml
new file mode 100644
index 00000000..53d3e914
--- /dev/null
+++ b/configs/meta_arch/depth-segmentator.yaml
@@ -0,0 +1,7 @@
+# @package _group_
+
+pl_class:
+  _target_: sharpf.modeling.DepthSegmentator
+monitor: val_balanced_accuracy
+loss:
+  _target_: torch.nn.functional.binary_cross_entropy_with_logits
diff --git a/configs/model/unet-segmentator.yaml b/configs/model/unet-segmentator.yaml
new file mode 100644
index 00000000..90d383eb
--- /dev/null
+++ b/configs/model/unet-segmentator.yaml
@@ -0,0 +1,19 @@
+# @package _group_
+model_name: PixelSegmentator
+params:
+  feature_extractor:
+    _target_: sharpf.modeling.Unet
+    params:
+      encoder_name: resnet50
+      decoder_use_batchnorm: true
+      decoder_channels: [256, 128, 64, 32, 16]
+      decoder_attention_type: null
+      in_channels: 1
+  segmentation_head:
+    - _target_: torch.nn.Conv2d
+      params:
+        in_channels: 16
+        out_channels: 1
+        kernel_size: 3
+        padding: 1
+    - _target_: torch.nn.Sigmoid
\ No newline at end of file
diff --git a/configs/transforms/depth-norm.yaml b/configs/transforms/depth-norm.yaml
new file mode 100644
index 00000000..09ee240b
--- /dev/null
+++ b/configs/transforms/depth-norm.yaml
@@ -0,0 +1,15 @@
+# @package _group_
+
+train:
+  - _target_: sharpf.utils.abc_utils.torch.TypeCast
+
+val:
+  - _target_: sharpf.utils.abc_utils.torch.TypeCast
+
+test:
+  - _target_: sharpf.utils.abc_utils.torch.TypeCast
+
+normalisation:
+  - data: ['standartize', 'quantile']
+
+
diff --git a/sharpf/modeling/__init__.py b/sharpf/modeling/__init__.py
index 22232382..205828c4 100644
--- a/sharpf/modeling/__init__.py
+++ b/sharpf/modeling/__init__.py
@@ -1,4 +1,5 @@
-from .meta_arch import PointSharpnessRegressor, DepthRegressor
+from .meta_arch import PointSharpnessRegressor, DepthRegressor, DepthSegmentator
+from .metrics import balanced_accuracy
 from .model import MODEL_REGISTRY, build_model, DGCNN, Unet, PixelRegressor
 from .modules import (
     AggregationMax,
diff --git a/sharpf/modeling/meta_arch/__init__.py b/sharpf/modeling/meta_arch/__init__.py
index 7840ba94..1aac6294 100644
--- a/sharpf/modeling/meta_arch/__init__.py
+++ b/sharpf/modeling/meta_arch/__init__.py
@@ -1,2 +1,3 @@
 from .point_sharpness_regressor import PointSharpnessRegressor
 from .depth_regressor import DepthRegressor
+from .depth_segmentator import DepthSegmentator
diff --git a/sharpf/modeling/meta_arch/depth_regressor.py b/sharpf/modeling/meta_arch/depth_regressor.py
index c449b623..79af283c 100644
--- a/sharpf/modeling/meta_arch/depth_regressor.py
+++ b/sharpf/modeling/meta_arch/depth_regressor.py
@@ -12,7 +12,7 @@
 from sharpf.utils.comm import get_batch_size
 from ..model.build import build_model
 from ...data import DepthMapIO
-from ...utils.abc_utils import LotsOfHdf5Files
+from ...utils.abc_utils.hdf5.dataset import LotsOfHdf5Files, DepthDataset
 from ...utils.abc_utils.torch import CompositeTransform
 
 log = logging.getLogger(__name__)
@@ -27,7 +27,8 @@ class DepthRegressor(LightningModule):
 
     def __init__(self, cfg):
         super().__init__()
-        self.hparams = cfg
+        self.hparams = cfg  # there should be a better official way later
+        self.task = 'regression'
         self.model = build_model(self.hparams.model)
         self.example_input_array = torch.rand(1, 1, 64, 64)
         self.data_dir = hydra.utils.to_absolute_path(self.hparams.data.data_dir)
@@ -43,7 +44,7 @@ def forward(self, x):
         return self.model(x)
 
     def training_step(self, batch, batch_idx):
-        points, distances = batch['image'], batch['distances']
+        points, distances = batch['image'], batch['distance_to_sharp']
         points = points.unsqueeze(1) if points.dim() == 3 else points
         preds = self.forward(points)
         loss = hydra.utils.instantiate(self.hparams.meta_arch.loss, preds, distances)
@@ -52,7 +53,7 @@ def training_step(self, batch, batch_idx):
         return result
 
     def _shared_eval_step(self, batch, batch_idx, prefix):
-        points, distances = batch['image'], batch['distances']
+        points, distances = batch['image'], batch['distance_to_sharp']
         points = points.unsqueeze(1) if points.dim() == 3 else points
         preds = self.forward(points)
 
@@ -94,12 +95,14 @@ def configure_optimizers(self):
     def _get_dataset(self, partition):
         if hasattr(self, f'{partition}_set') and getattr(self, f'{partition}_set') is not None:
             return getattr(self, f'{partition}_set')
-        transform = CompositeTransform([hydra.utils.instantiate(tf) for tf in self.hparams.transforms[partition]])
-        return LotsOfHdf5Files(
+
+        transform = CompositeTransform([hydra.utils.instantiate(tf) for tf in self.cfg.transforms[partition]])
+        return DepthDataset(
             data_dir=self.data_dir,
             io=DepthMapIO,
-            data_label=self.hparams.data.data_label,
-            target_label=self.hparams.data.target_label,
+            data_label=self.cfg.data.data_label,
+            target_label=self.cfg.data.target_label,
+            task=self.task,
             partition=partition,
             transform=transform,
             max_loaded_files=self.hparams.data.max_loaded_files
diff --git a/sharpf/modeling/meta_arch/depth_segmentator.py b/sharpf/modeling/meta_arch/depth_segmentator.py
new file mode 100644
index 00000000..86f9c25d
--- /dev/null
+++ b/sharpf/modeling/meta_arch/depth_segmentator.py
@@ -0,0 +1,137 @@
+import logging
+
+import hydra
+import torch
+import torch.nn as nn
+from pytorch_lightning import TrainResult
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.metrics.functional import stat_scores
+from torch.utils.data import DataLoader
+
+from sharpf.utils.comm import get_batch_size, all_gather, synchronize
+from ..metrics import balanced_accuracy
+from ..model.build import build_model
+from ...data import DepthMapIO
+from ...utils.abc_utils.hdf5 import DepthDataset
+from ...utils.abc_utils.torch import CompositeTransform
+
+log = logging.getLogger(__name__)
+
+
+class DepthSegmentator(LightningModule):
+
+    def __init__(self, cfg):
+        super().__init__()
+        self.hparams = cfg
+        self.task = 'segmentation'
+        self.model = build_model(self.hparams.model)
+        self.example_input_array = torch.rand(1, 1, 64, 64)
+        self.data_dir = hydra.utils.to_absolute_path(self.hparams.data.data_dir)
+
+        dist_backend = self.hparams.trainer.distributed_backend
+        if (dist_backend is not None and 'ddp' in dist_backend) or (
+                dist_backend is None and self.hparams.trainer.gpus is not None and (
+                self.hparams.trainer.gpus > 1 or self.hparams.trainer.num_nodes > 1)):
+            log.info('Converting BatchNorm to SyncBatchNorm. Do not forget other batch-dimension dependent operations.')
+            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
+
+    def forward(self, x, as_mask=True):
+        out = self.model(x)
+        if as_mask:
+            return (out.sigmoid() > 0.5).long()
+        return out
+
+    def training_step(self, batch, batch_idx):
+        points, target = batch['image'], batch['close_to_sharp_mask']
+        points = points.unsqueeze(1) if points.dim() == 3 else points
+        preds = self.forward(points, as_mask=False)
+        loss = hydra.utils.instantiate(self.hparams.meta_arch.loss, preds, target)
+        result = TrainResult(minimize=loss)
+        result.log('train_loss', loss, prog_bar=True)
+        return result
+
+    def _shared_eval_step(self, batch, batch_idx, prefix):
+        points, target = batch['image'], batch['close_to_sharp_mask']
+        points = points.unsqueeze(1) if points.dim() == 3 else points
+        preds = self.forward(points, as_mask=True)
+        stats = [list(stat_scores(preds[i], target[i], class_index=1)) for i in range(preds.size(0))]
+        tp, fp, tn, fn, sup = torch.Tensor(stats).to(preds.device).T.unsqueeze(2)  # each of size (batch, 1)
+        return {'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn, 'sup': sup}
+
+    def _shared_eval_epoch_end(self, outputs, prefix):
+        # gather across sub-batches
+        tp = torch.cat([output['tp'] for output in outputs], dim=0)
+        fp = torch.cat([output['fp'] for output in outputs], dim=0)
+        tn = torch.cat([output['tn'] for output in outputs], dim=0)
+        fn = torch.cat([output['fn'] for output in outputs], dim=0)
+
+        # gather results across gpus
+        synchronize()
+        tp = torch.cat(all_gather(tp), dim=0)
+        fp = torch.cat(all_gather(fp), dim=0)
+        tn = torch.cat(all_gather(tn), dim=0)
+        fn = torch.cat(all_gather(fn), dim=0)
+
+        # calculate metrics
+        ba = balanced_accuracy(tp, fp, tn, fn)
+
+        logs = {f'{prefix}_balanced_accuracy': ba}
+        return {f'{prefix}_balanced_accuracy': ba, 'log': logs}
+
+    def validation_step(self, batch, batch_idx):
+        return self._shared_eval_step(batch, batch_idx, prefix='val')
+
+    def test_step(self, batch, batch_idx):
+        return self._shared_eval_step(batch, batch_idx, prefix='test')
+
+    def validation_epoch_end(self, outputs):
+        return self._shared_eval_epoch_end(outputs, prefix='val')
+
+    def test_epoch_end(self, outputs):
+        return self._shared_eval_epoch_end(outputs, prefix='test')
+
+    def configure_optimizers(self):
+        optimizer = hydra.utils.instantiate(self.hparams.opt, params=self.parameters())
+        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer)
+        return [optimizer], [scheduler]
+
+    def _get_dataset(self, partition):
+        if hasattr(self, f'{partition}_set') and getattr(self, f'{partition}_set') is not None:
+            return getattr(self, f'{partition}_set')
+        transform = CompositeTransform([hydra.utils.instantiate(tf) for tf in self.hparams.transforms[partition]])
+        if 'normalisation' in self.hparams.transforms.keys():
+            normalisation = self.hparams.transforms['normalisation']
+        else:
+            normalisation = None
+
+        return DepthDataset(
+            data_dir=self.data_dir,
+            io=DepthMapIO,
+            data_label=self.hparams.data.data_label,
+            target_label=self.hparams.data.target_label,
+            task=self.task,
+            partition=partition,
+            transform=transform,
+            max_loaded_files=self.hparams.data.max_loaded_files,
+            normalisation=normalisation
+        )
+
+    def _get_dataloader(self, partition):
+        dataset = self._get_dataset(partition)
+        num_workers = self.hparams.data_loader[partition].num_workers
+        batch_size = get_batch_size(self.hparams.data_loader[partition].total_batch_size)
+        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
+
+    def setup(self, stage: str):
+        self.train_set = self._get_dataset('train') if stage == 'fit' else None
+        self.val_set = self._get_dataset('val') if stage == 'fit' else None
+        self.test_set = self._get_dataset('test') if stage == 'test' else None
+
+    def train_dataloader(self):
+        return self._get_dataloader('train')
+
+    def val_dataloader(self):
+        return self._get_dataloader('val')
+
+    def test_dataloader(self):
+        return self._get_dataloader('val')  # FIXME
diff --git a/sharpf/modeling/metrics.py b/sharpf/modeling/metrics.py
new file mode 100644
index 00000000..666bc942
--- /dev/null
+++ b/sharpf/modeling/metrics.py
@@ -0,0 +1,21 @@
+import torch
+
+
+def balanced_accuracy(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate balanced accuracy for one class based on the provided statistics.
+
+    Args:
+        tp (Tensor): of shape (B, 1). True positive values.
+        fp (Tensor): of shape (B, 1). False positive values.
+        tn (Tensor): of shape (B, 1). True negative values.
+        fn (Tensor): of shape (B, 1). False negative values.
+
+    Returns:
+        torch.Tensor: balanced accuracy value
+    """
+    tpr = tp / (tp + fn)  # (B, 1)
+    tnr = tn / (tn + fp)  # (B, 1)
+    tpr = torch.where(torch.isnan(tpr), tnr, tpr)  # (B, 1)
+    tnr = torch.where(torch.isnan(tnr), tpr, tnr)  # (B, 1)
+    return 0.5 * torch.mean(tpr + tnr)
diff --git a/sharpf/modeling/model/pixel_regressor.py b/sharpf/modeling/model/pixel_regressor.py
index 269e72ef..6de62f41 100644
--- a/sharpf/modeling/model/pixel_regressor.py
+++ b/sharpf/modeling/model/pixel_regressor.py
@@ -27,3 +27,25 @@ def from_config(cls, cfg: DictConfig):
             "feature_extractor": hydra.utils.instantiate(cfg.feature_extractor),
             "regression_head": nn.Sequential(*[hydra.utils.instantiate(node) for node in cfg.regression_head])
         }
+
+# still believe that a single or abstract class for pixel-task models would be better
+@MODEL_REGISTRY.register()
+class PixelSegmentator(nn.Module):
+    @configurable
+    def __init__(self, feature_extractor, segmentation_head):
+        super().__init__()
+        self.feature_extractor = feature_extractor
+        self.segmentation_head = segmentation_head
+
+    def initialize(self):
+        initialize_head(self.segmentation_head)
+
+    def forward(self, x):
+        return self.segmentation_head(self.feature_extractor(x))
+
+    @classmethod
+    def from_config(cls, cfg: DictConfig):
+        return {
+            "feature_extractor": hydra.utils.instantiate(cfg.feature_extractor),
+            "segmentation_head": nn.Sequential(*[hydra.utils.instantiate(node) for node in cfg.segmentation_head])
+        }
diff --git a/sharpf/utils/abc_utils/hdf5/__init__.py b/sharpf/utils/abc_utils/hdf5/__init__.py
index 7d90035e..02555e6f 100644
--- a/sharpf/utils/abc_utils/hdf5/__init__.py
+++ b/sharpf/utils/abc_utils/hdf5/__init__.py
@@ -1 +1 @@
-from .dataset import Hdf5File, LotsOfHdf5Files
+from .dataset import Hdf5File, LotsOfHdf5Files, DepthDataset
diff --git a/sharpf/utils/abc_utils/hdf5/dataset.py b/sharpf/utils/abc_utils/hdf5/dataset.py
index 7ab917b4..0fe86b7f 100644
--- a/sharpf/utils/abc_utils/hdf5/dataset.py
+++ b/sharpf/utils/abc_utils/hdf5/dataset.py
@@ -11,6 +11,7 @@
 
 log = logging.getLogger(__name__)
 
+high_res_quantile = 7.4776
 
 class Hdf5File(Dataset):
     def __init__(self, filename, io, data_label=None, target_label=None, labels=None, preload=True,
@@ -141,3 +142,87 @@ def __getitem__(self, index):
             file_index_to_unload = np.random.choice(loaded_file_indexes)
             self.files[file_index_to_unload].unload()
         return item
+
+class DepthDataset(LotsOfHdf5Files):
+
+    def __init__(self, io, data_dir, data_label, target_label, task, partition=None,
+                 transform=None, normalisation=['quantile', 'standartize'], max_loaded_files=0):
+        super().__init__(data_dir=data_dir, io=io,
+                         data_label=data_label, target_label=target_label,
+                         labels=None,
+                         partition=partition,
+                         transform=transform,
+                         max_loaded_files=max_loaded_files)
+        self.data_dir = data_dir
+        self.task = task
+        self.quality = self._get_quantity()
+        self.normalisation = normalisation
+
+    def _get_quantity(self):
+        data_dir_split = self.data_dir.split('_')
+        if 'high' in data_dir_split:
+            return 'high'
+        elif 'low' in data_dir_split:
+            return 'low'
+        elif 'med' in data_dir_split:
+            return 'med'
+
+    def quantile_normalize(self, data):
+        # mask -> min shift -> quantile
+
+        norm_data = np.copy(data)
+        mask_obj = np.where(norm_data != 0)
+        mask_back = np.where(norm_data == 0)
+        norm_data[mask_back] = norm_data.max() + 1.0  # push the background beyond the object's depth range
+        norm_data -= norm_data[mask_obj].min()
+
+        norm_data /= high_res_quantile
+
+        return norm_data
+
+    def standartize(self, data):
+        # zero mean, then scale by the largest per-row norm
+
+        standart_data = np.copy(data)
+        standart_data -= np.mean(standart_data)
+        std = np.linalg.norm(standart_data, axis=1).max()
+        if std > 0:
+            standart_data /= std
+
+        return standart_data
+
+    def __getitem__(self, index):
+
+        item = super().__getitem__(index)
+        data, target = item['image'], item['distances']
+        mask_1 = (np.copy(data) != 0.0).astype(float)  # mask for object
+        mask_2 = np.where(data == 0)  # mask for background
+
+        if 'quantile' in self.normalisation:
+            data = self.quantile_normalize(data)
+        if 'standartize' in self.normalisation:
+            data = self.standartize(data)
+
+        dist_new = np.copy(target)
+        dist_mask = dist_new * mask_1  # select object points
+        dist_mask[mask_2] = 1.0  # background points have the max distance to sharp features
+        close_to_sharp = np.array(~np.isnan(dist_mask) & (dist_mask < 1.)).astype(float)
+
+        output = {}
+
+        if self.task == 'two-heads':
+            # regression + segmentation (or a two-head network) has two targets:
+            # distance field and segmented close-to-sharp region of the object
+            target = torch.cat([torch.FloatTensor(dist_mask).unsqueeze(0), torch.FloatTensor(close_to_sharp).unsqueeze(0)], dim=0)
+            output['distance_and_close_to_sharp'] = target
+        if self.task == 'segmentation':
+            target = torch.FloatTensor(close_to_sharp).unsqueeze(0)
+            output['close_to_sharp_mask'] = target
+        elif self.task == 'regression':
+            target = torch.FloatTensor(dist_mask).unsqueeze(0)
+            output['distance_to_sharp'] = target
+
+        data = torch.FloatTensor(data).unsqueeze(0)
+        output['image'] = data
+
+        return output
diff --git a/sharpf/utils/comm.py b/sharpf/utils/comm.py
index d58908b8..b9e8d553 100644
--- a/sharpf/utils/comm.py
+++ b/sharpf/utils/comm.py
@@ -1,3 +1,14 @@
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import pickle
+
+import numpy as np
+import torch
 import torch.distributed as dist
 
 
@@ -15,3 +26,214 @@
         f"Total batch size ({total_batch_size}) must be divisible by the number of gpus ({world_size})."
     batch_size = total_batch_size // world_size
     return batch_size
+
+
+def get_rank() -> int:
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process() -> bool:
+    return get_rank() == 0
+
+
+def synchronize():
+    """
+    Helper function to synchronize (barrier) among all processes when
+    using distributed training
+    """
+    if not dist.is_available():
+        return
+    if not dist.is_initialized():
+        return
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return
+    dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+    """
+    Return a process group based on gloo backend, containing all the ranks
+    The result is cached.
+    """
+    if dist.get_backend() == "nccl":
+        return dist.new_group(backend="gloo")
+    else:
+        return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+    backend = dist.get_backend(group)
+    assert backend in ["gloo", "nccl"]
+    device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+    buffer = pickle.dumps(data)
+    if len(buffer) > 1024 ** 3:
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+                get_rank(), len(buffer) / (1024 ** 3), device
+            )
+        )
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device=device)
+    return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+    """
+    Returns:
+        list[int]: size of the tensor, on each rank
+        Tensor: padded tensor that has the max size
+    """
+    world_size = dist.get_world_size(group=group)
+    assert (
+        world_size >= 1
+    ), "comm.gather/all_gather must be called from ranks within the given group!"
+    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+    size_list = [
+        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+    ]
+    dist.all_gather(size_list, local_size, group=group)
+    size_list = [int(size.item()) for size in size_list]
+
+    max_size = max(size_list)
+
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    if local_size != max_size:
+        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+        tensor = torch.cat((tensor, padding), dim=0)
+    return size_list, tensor
+
+
+def all_gather(data, group=None):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+    Args:
+        data: any picklable object
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return [data]
+
+    tensor = _serialize_to_tensor(data, group)
+
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    tensor_list = [
+        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+    ]
+    dist.all_gather(tensor_list, tensor, group=group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def gather(data, dst=0, group=None):
+    """
+    Run gather on arbitrary picklable data (not necessarily tensors).
+
+    Args:
+        data: any picklable object
+        dst (int): destination rank
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+
+    Returns:
+        list[data]: on dst, a list of data gathered from each rank. Otherwise,
+            an empty list.
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group=group) == 1:
+        return [data]
+    rank = dist.get_rank(group=group)
+
+    tensor = _serialize_to_tensor(data, group)
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+    # receiving Tensor from all ranks
+    if rank == dst:
+        max_size = max(size_list)
+        tensor_list = [
+            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+        ]
+        dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+        data_list = []
+        for size, tensor in zip(size_list, tensor_list):
+            buffer = tensor.cpu().numpy().tobytes()[:size]
+            data_list.append(pickle.loads(buffer))
+        return data_list
+    else:
+        dist.gather(tensor, [], dst=dst, group=group)
+        return []
+
+
+def shared_random_seed():
+    """
+    Returns:
+        int: a random number that is the same across all workers.
+            If workers need a shared RNG, they can use this shared seed to
+            create one.
+
+    All workers must call this function, otherwise it will deadlock.
+    """
+    ints = np.random.randint(2 ** 31)
+    all_ints = all_gather(ints)
+    return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Reduce the values in the dictionary from all processes so that process with rank
+    0 has the reduced results.
+
+    Args:
+        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+        average (bool): whether to do average or sum
+
+    Returns:
+        a dict with the same keys as input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.reduce(values, dst=0)
+        if dist.get_rank() == 0 and average:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
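
A toy illustration of how the new DepthDataset.__getitem__ builds the close_to_sharp_mask target for task='segmentation'. The depth map and distance values below are hypothetical; in the dataset they come from the 'image' and 'distances' fields of the HDF5 files.

import numpy as np
import torch

image = np.array([[2.3, 1.7],
                  [0.0, 0.0]])      # depth map, 0.0 marks background pixels
distances = np.array([[0.1, 0.6],
                      [0.4, 0.9]])  # truncated distance-to-sharp field

mask_1 = (np.copy(image) != 0.0).astype(float)   # mask for object
mask_2 = np.where(image == 0)                    # mask for background

dist_mask = distances * mask_1                   # select object points
dist_mask[mask_2] = 1.0                          # background points get the max distance
close_to_sharp = np.array(~np.isnan(dist_mask) & (dist_mask < 1.)).astype(float)

target = torch.FloatTensor(close_to_sharp).unsqueeze(0)   # shape (1, H, W)
print(target)  # only object pixels strictly closer than 1.0 to a sharp feature are labelled 1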
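
A quick sanity check of the new balanced_accuracy metric, collecting per-sample statistics the same way DepthSegmentator._shared_eval_step does. This is a minimal sketch; it assumes the pytorch_lightning version this repo pins, where stat_scores is importable from pytorch_lightning.metrics.functional.

import torch
from pytorch_lightning.metrics.functional import stat_scores

from sharpf.modeling import balanced_accuracy

preds = torch.tensor([[1, 1, 0, 0],
                      [1, 0, 1, 0]])    # two predicted binary masks, flattened
target = torch.tensor([[1, 0, 0, 0],
                       [1, 0, 1, 1]])   # the corresponding ground-truth masks

# per-sample (tp, fp, tn, fn, support), as in _shared_eval_step
stats = [list(stat_scores(preds[i], target[i], class_index=1)) for i in range(preds.size(0))]
tp, fp, tn, fn, sup = torch.Tensor(stats).T.unsqueeze(2)  # each of shape (B, 1)

print(balanced_accuracy(tp, fp, tn, fn))  # tensor(0.8333): batch mean of (TPR + TNR) / 2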