Skip to content

Commit

Permalink
Revert "Make input a dictionary for multi-modal object detection (#95)"
Browse files Browse the repository at this point in the history
This reverts commit de77a9d.
  • Loading branch information
mzweilin committed Jul 14, 2023
1 parent 9a2531f commit c5bf847
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

defaults:
- COCO_TorchvisionFasterRCNN
- override /model/[email protected]: tuple_tensorizer_normalizer
- override /datamodule: armory_carla_over_objdet_perturbable_mask

task_name: "ArmoryCarlaOverObjDet_TorchvisionFasterRCNN"
Expand Down
9 changes: 0 additions & 9 deletions mart/configs/model/modules/tuple_normalizer.yaml

This file was deleted.

14 changes: 0 additions & 14 deletions mart/configs/model/modules/tuple_tensorizer_normalizer.yaml

This file was deleted.

8 changes: 7 additions & 1 deletion mart/configs/model/torchvision_object_detection.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# We simply wrap a torchvision object detection model for validation.
defaults:
- modular
- /model/[email protected]: tuple_normalizer

training_step_log:
loss: "loss"
Expand All @@ -13,6 +12,13 @@ test_sequence: ???
output_preds_key: "losses_and_detections.eval"

modules:
preprocessor:
_target_: mart.transforms.TupleTransforms
transforms:
_target_: torchvision.transforms.Normalize
mean: 0
std: 255

losses_and_detections:
# Return losses in the training mode and predictions in the eval mode in one pass.
_target_: mart.models.DualMode
Expand Down
17 changes: 0 additions & 17 deletions mart/datamodules/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,6 @@ def _load_target(self, id: int) -> List[Any]:

return {"image_id": id, "file_name": file_name, "annotations": annotations}

def __getitem__(self, index: int):
"""Override __getitem__() to dictionarize input for multi-modality datasets.
This runs after _load_image() and transforms(), while transforms() typically converts
images to tensors.
"""

image, target_dict = super().__getitem__(index)

# Convert multi-modal input to a dictionary.
if self.modalities is not None:
# We assume image is a multi-channel tensor, with each modality including 3 channels.
assert image.shape[0] == len(self.modalities) * 3
image = dict(zip(self.modalities, image.split(3)))

return image, target_dict


# Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/utils.py#L203
def collate_fn(batch):
Expand Down
22 changes: 1 addition & 21 deletions mart/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@
import torch
from torchvision.transforms import transforms as T

__all__ = [
"Denormalize",
"Cat",
"Permute",
"Unsqueeze",
"Squeeze",
"Chunk",
"TupleTransforms",
"GetItems",
]
# Explicit public API of this transforms module.
__all__ = ["Denormalize", "Cat", "Permute", "Unsqueeze", "Squeeze", "Chunk", "TupleTransforms"]


class Denormalize(T.Normalize):
Expand Down Expand Up @@ -90,14 +81,3 @@ def __init__(self, transforms):
def forward(self, x_tuple):
output_tuple = tuple(self.transforms(x) for x in x_tuple)
return output_tuple


class GetItems:
    """Select values from a dictionary by a fixed, ordered sequence of keys."""

    def __init__(self, keys):
        # Keys are looked up in order; a missing key raises KeyError at call time.
        self.keys = keys

    def __call__(self, x):
        """Return the values of *x* at self.keys, as a list in key order."""
        return [x[key] for key in self.keys]

0 comments on commit c5bf847

Please sign in to comment.