Change number of coco classes in detection recipe #5999

Open

rvandeghen opened this issue May 12, 2022 · 3 comments

@rvandeghen
Contributor

🚀 The feature

Infer the number of classes from the dataset without hard-coding it, for example with

# Proposed change to get_dataset in the detection reference script:
# read the number of classes from the annotations instead of hard-coding it.
def get_dataset(name, image_set, transform, data_path):
    paths = {"coco": (data_path, get_coco), "coco_kp": (data_path, get_coco_kp)}
    p, ds_fn = paths[name]

    ds = ds_fn(p, image_set=image_set, transforms=transform)
    num_classes = len(ds.coco.cats)  # categories actually present in the annotation file
    return ds, num_classes
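A hypothetical call site in the training script would then pick up the class count automatically (the data path below is just a placeholder):

# Hypothetical usage of the get_dataset above; the data path is a placeholder.
dataset, num_classes = get_dataset("coco", "train", transform=None, data_path="/datasets/coco")
print(num_classes)  # 80 for the standard COCO 2017 detection annotations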

This means that we need to remap the class ids to contiguous indices, for example with

import copy

import torchvision


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms
        self.mapping = self.map_coco()

    def map_coco(self):
        # Map the non-contiguous COCO category ids to contiguous indices,
        # starting at 1 so that 0 stays free for the background class.
        mapping_coco = {}
        for idx, c in enumerate(self.coco.cats):
            mapping_coco[c] = idx + 1
        return mapping_coco

    def __getitem__(self, idx):
        img, target = super().__getitem__(idx)
        target = copy.deepcopy(target)
        for t in target:
            t["category_id"] = self.mapping[t["category_id"]]
        image_id = self.ids[idx]
        target = dict(image_id=image_id, annotations=target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target
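For evaluation against the original annotations (e.g. with pycocotools), the predicted labels would have to be mapped back to the original COCO category ids. A minimal sketch, assuming the CocoDetection wrapper above (build_inverse_mapping and remap_labels are hypothetical helpers):

def build_inverse_mapping(dataset):
    # Invert the contiguous-index mapping so that predicted labels can be
    # converted back to the original COCO category ids before COCO evaluation.
    return {contiguous: original for original, contiguous in dataset.mapping.items()}

def remap_labels(labels, inverse_mapping):
    # `labels` is an iterable of predicted class indices, e.g. prediction["labels"].tolist().
    return [inverse_mapping[int(label)] for label in labels]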

Lastly, the number of classes passed to the model could be num_classes + 1 (one extra slot for the background class)

model = torchvision.models.detection.__dict__[args.model](
    weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes + 1, **kwargs
)
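As a quick sanity check, for a two-stage detector such as Faster R-CNN the classification head then exposes the reduced class count. A sketch, assuming the multi-weight API and a Faster R-CNN variant:

import torchvision

num_classes = 80  # e.g. what len(ds.coco.cats) returns for COCO 2017 detection
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=None, weights_backbone=None, num_classes=num_classes + 1
)
assert model.roi_heads.box_predictor.cls_score.out_features == num_classes + 1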

Motivation, pitch

The number of classes used to train on COCO is hard-coded to 90 + 1:

paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}

This value does not reflect the actual number of classes in COCO (80) but rather the highest category index, since some indices are unused.

Since the number of classes determines the size of the classification and box-regression heads of detection architectures

def __init__(self, in_channels, num_classes):
    super().__init__()
    self.cls_score = nn.Linear(in_channels, num_classes)
    self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

this results in a small increase in memory usage (and may lead to unexpected behavior).
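To put a rough number on that overhead in the box head above, assuming in_channels = 1024 (the default representation size of the Faster R-CNN box head), the 10 unused class slots cost roughly 50k extra parameters:

# Back-of-the-envelope estimate; in_channels = 1024 is an assumption.
in_channels = 1024
extra_classes = 91 - 81
extra_params = extra_classes * in_channels        # cls_score weights
extra_params += extra_classes * 4 * in_channels   # bbox_pred weights
extra_params += extra_classes * (1 + 4)           # biases
print(extra_params)  # 51250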

Alternatives

No response

Additional context

No response

@datumbox
Contributor

@rvandeghen Thanks for the proposal.

I agree it's annoying that all of TorchVision's models keep the 10 unused categories. It has been this way for years, and unfortunately "fixing" it now means breaking some models. There might be a mechanism that would allow us to keep the old models as-is and fix it for new ones, but I suspect it will introduce quite a lot of extra code.

So my OCD is definitely bothered by this issue whenever I see it, but I am tempted to say it's probably a no-fix because of the implications it can have on BC (backwards compatibility).

Let me know what you think.

@rvandeghen
Contributor Author

@datumbox I fully agree that there is a feeling of "it's too late now". However, I thought this would have been good timing, since we have done #2707 and #5307.
There are currently 8 detection models, so I think retraining them should not be a huge effort, and it could be worth using that strategy for the upcoming ones.

I can retrain some of them if necessary and provide the weights.

Let me know.

@datumbox
Contributor

datumbox commented May 12, 2022

The challenge is that we would need two versions of the label metadata, plus custom mechanisms in the reference scripts so that one can validate/run inference with the old models while training new ones with the reduced classes. Then we would need to update our documentation and code examples to indicate that some models use a reduced category list, etc. Perhaps there are more things we would need to change.

Unfortunately we can't really replace the existing 8 models with new ones; this would violate the strong BC guarantees that TorchVision offers. The use-case in mind is transfer learning: someone might have already fitted new models on top of our weights, and those would no longer work or be reproducible.

As much as I would like to fix this, I don't think it's something we can fix without breaking BC or adding a lot of extra code to handle corner cases. I'll leave this proposal open, in case we can adopt it if we introduce a major BC-breaking change in the Detection area. There are quite a few things we need to do (such as supporting multiple losses, removing transforms from within the model, etc.), and many of them break BC. Perhaps we can sneak in this change then.
