diff --git a/examples/carla_overhead_object_detection/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml b/examples/carla_overhead_object_detection/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml index 991a87e7..a4e5660f 100644 --- a/examples/carla_overhead_object_detection/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml +++ b/examples/carla_overhead_object_detection/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml @@ -2,7 +2,6 @@ defaults: - COCO_TorchvisionFasterRCNN - - override /model/modules@model.modules.preprocessor: tuple_tensorizer_normalizer - override /datamodule: armory_carla_over_objdet_perturbable_mask task_name: "ArmoryCarlaOverObjDet_TorchvisionFasterRCNN" diff --git a/mart/configs/model/modules/tuple_normalizer.yaml b/mart/configs/model/modules/tuple_normalizer.yaml deleted file mode 100644 index 13fdf97a..00000000 --- a/mart/configs/model/modules/tuple_normalizer.yaml +++ /dev/null @@ -1,9 +0,0 @@ -# @package model.modules.preprocessor -_target_: mart.transforms.TupleTransforms -transforms: - _target_: torchvision.transforms.Compose - # Normalize to [0, 1]. - transforms: - - _target_: torchvision.transforms.Normalize - mean: 0 - std: 255 diff --git a/mart/configs/model/modules/tuple_tensorizer_normalizer.yaml b/mart/configs/model/modules/tuple_tensorizer_normalizer.yaml deleted file mode 100644 index 21ac6fb7..00000000 --- a/mart/configs/model/modules/tuple_tensorizer_normalizer.yaml +++ /dev/null @@ -1,14 +0,0 @@ -# @package model.modules.preprocessor -defaults: - - tuple_normalizer - -# Convert dictionary input into tensor input, then normalize to [0, 1]. -transforms: - transforms: - - _target_: mart.transforms.GetItems - keys: ["rgb"] - - _target_: mart.transforms.Cat - dim: 0 - - _target_: torchvision.transforms.Normalize - mean: 0 - std: 255 diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index 1bbd678c..534f0fc9 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -1,7 +1,6 @@ # We simply wrap a torchvision object detection model for validation. defaults: - modular - - /model/modules@modules.preprocessor: tuple_normalizer training_step_log: loss: "loss" @@ -13,6 +12,13 @@ test_sequence: ??? output_preds_key: "losses_and_detections.eval" modules: + preprocessor: + _target_: mart.transforms.TupleTransforms + transforms: + _target_: torchvision.transforms.Normalize + mean: 0 + std: 255 + losses_and_detections: # Return losses in the training mode and predictions in the eval mode in one pass. _target_: mart.models.DualMode diff --git a/mart/datamodules/coco.py b/mart/datamodules/coco.py index 42ddcebb..2a05d806 100644 --- a/mart/datamodules/coco.py +++ b/mart/datamodules/coco.py @@ -68,23 +68,6 @@ def _load_target(self, id: int) -> List[Any]: return {"image_id": id, "file_name": file_name, "annotations": annotations} - def __getitem__(self, index: int): - """Override __getitem__() to dictionarize input for multi-modality datasets. - - This runs after _load_image() and transforms(), while transforms() typically converts - images to tensors. - """ - - image, target_dict = super().__getitem__(index) - - # Convert multi-modal input to a dictionary. - if self.modalities is not None: - # We assume image is a multi-channel tensor, with each modality including 3 channels. - assert image.shape[0] == len(self.modalities) * 3 - image = dict(zip(self.modalities, image.split(3))) - - return image, target_dict - # Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/utils.py#L203 def collate_fn(batch): diff --git a/mart/transforms/transforms.py b/mart/transforms/transforms.py index 4c7f29f7..bc67d33a 100644 --- a/mart/transforms/transforms.py +++ b/mart/transforms/transforms.py @@ -7,16 +7,7 @@ import torch from torchvision.transforms import transforms as T -__all__ = [ - "Denormalize", - "Cat", - "Permute", - "Unsqueeze", - "Squeeze", - "Chunk", - "TupleTransforms", - "GetItems", -] +__all__ = ["Denormalize", "Cat", "Permute", "Unsqueeze", "Squeeze", "Chunk", "TupleTransforms"] class Denormalize(T.Normalize): @@ -90,14 +81,3 @@ def __init__(self, transforms): def forward(self, x_tuple): output_tuple = tuple(self.transforms(x) for x in x_tuple) return output_tuple - - -class GetItems: - """Get a list of values with a list of keys from a dictionary.""" - - def __init__(self, keys): - self.keys = keys - - def __call__(self, x): - x_list = [x[key] for key in self.keys] - return x_list