
fix bug when the target is empty in FCOS #5267

Merged · 4 commits · Jan 26, 2022
Changes from 3 commits
11 changes: 11 additions & 0 deletions test/test_models_detection_negative_samples.py
@@ -143,6 +143,17 @@ def test_forward_negative_sample_retinanet(self):

         assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))

+    def test_forward_negative_sample_fcos(self):
+        model = torchvision.models.detection.fcos_resnet50_fpn(
+            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
+        )
+
+        images, targets = self._make_empty_sample()
+        loss_dict = model(images, targets)
+
+        assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))
+        assert_equal(loss_dict["bbox_ctrness"], torch.tensor(0.0))
+
     def test_forward_negative_sample_ssd(self):
         model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False)

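The `_make_empty_sample` helper used by the new test is defined elsewhere in this test file and is not shown in the hunk. As a rough sketch of what such a helper needs to provide — plain-Python lists standing in for the real tensors, and the name `make_empty_sample` purely illustrative — the shape of the data is the point:

```python
def make_empty_sample(height=100, width=100):
    # Hypothetical stand-in for the suite's _make_empty_sample helper:
    # one 3-channel image and a target dict with zero ground-truth boxes.
    image = [[[0.0] * width for _ in range(height)] for _ in range(3)]  # C, H, W
    target = {
        "boxes": [],   # would be a (0, 4) float tensor in the real helper
        "labels": [],  # would be a (0,) int64 tensor
    }
    return [image], [target]
```

In the real test these are `torch` tensors; the empty `boxes`/`labels` entries are what exercise the model's empty-target code path.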
23 changes: 14 additions & 9 deletions torchvision/models/detection/fcos.py
@@ -59,9 +59,13 @@ def compute_loss(
         all_gt_classes_targets = []
         all_gt_boxes_targets = []
         for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
-            gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
+            if len(targets_per_image["labels"]) == 0:
Contributor @datumbox commented on Jan 24, 2022:
We might need to confirm that this works as expected and doesn't produce errors when find_unused_params=False (see the discussion at #2784 (comment)).

The runtime error to look for is something like:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).

@jdsgomes Could you confirm by kicking off a run for an epoch?

Contributor (PR author) replied:

@datumbox thanks, I had not considered this problem. By the way, I would like to know how RetinaNet handles this.

Contributor @datumbox replied on Jan 24, 2022:

We had this issue previously with a couple of models; RetinaNet was one of them. The way we avoided it was by rewriting the loss estimations so that the vector operations can cope with empty indices. Let's first check whether it is actually an issue before starting to rewrite anything.

Edit: I found the PR with the patch: #3032

Contributor (PR author) replied:

@datumbox I think the current implementation will not have this trouble: as you can see, the regression loss and centerness loss are already written so that the vector operations can cope with empty indices. To be safe, though, we had better check once. I don't have an empty dataset, so could you help check this?

Contributor:

That's our understanding as well: you should be OK. A single run of the training scripts for 1 epoch should be enough to confirm whether it's a problem (at least that was the case previously). I'll sync with Joao to confirm. :)

Contributor:

Looks like the other user confirmed the patch works, see #5266 (comment).

Contributor:

I can confirm that the patch works and also that after training for an epoch no runtime errors are observed.

+                gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
+                gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
+            else:
+                gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
+                gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
             gt_classes_targets[matched_idxs_per_image < 0] = -1  # background
-            gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
             all_gt_classes_targets.append(gt_classes_targets)
             all_gt_boxes_targets.append(gt_boxes_targets)
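The control flow of the patched block can be mirrored in plain Python, with lists standing in for tensors (`gather_gt_targets` is an illustrative name, not a torchvision function):

```python
def gather_gt_targets(labels, boxes, matched_idxs):
    # Mirror of the patched FCOS logic: when an image has no ground-truth
    # objects, synthesize all-zero targets instead of indexing empty arrays.
    if len(labels) == 0:
        gt_classes = [0] * len(matched_idxs)
        gt_boxes = [[0.0, 0.0, 0.0, 0.0] for _ in matched_idxs]
    else:
        gt_classes = [labels[max(i, 0)] for i in matched_idxs]  # clip(min=0)
        gt_boxes = [boxes[max(i, 0)] for i in matched_idxs]
    # Anchors with no match (index < 0) are marked as background (-1).
    gt_classes = [-1 if i < 0 else c for c, i in zip(gt_classes, matched_idxs)]
    return gt_classes, gt_boxes
```

With no ground truth, every entry of `matched_idxs` is negative, so the zero-filled class targets all end up marked as background and the foreground losses reduce to an empty (zero) sum, which is exactly what the new test asserts.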

@@ -95,13 +99,14 @@ def compute_loss(
         ]
         bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
         if len(bbox_reg_targets) == 0:
-            bbox_reg_targets.new_zeros(len(bbox_reg_targets))
-        left_right = bbox_reg_targets[:, :, [0, 2]]
-        top_bottom = bbox_reg_targets[:, :, [1, 3]]
-        gt_ctrness_targets = torch.sqrt(
-            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
-            * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
-        )
+            gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
+        else:
+            left_right = bbox_reg_targets[:, :, [0, 2]]
+            top_bottom = bbox_reg_targets[:, :, [1, 3]]
+            gt_ctrness_targets = torch.sqrt(
+                (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
+                * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
+            )
         pred_centerness = bbox_ctrness.squeeze(dim=2)
         loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
             pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"
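For reference, the centerness target computed in the non-empty branch reduces, per location, to the standard FCOS definition. A scalar version in pure Python (illustrative only; the real code operates on batched tensors):

```python
import math

def centerness(left, top, right, bottom):
    # FCOS centerness for one foreground location: the square root of the
    # product of the min/max ratios of the horizontal and vertical
    # regression targets (distances to the four box sides).
    return math.sqrt(
        (min(left, right) / max(left, right))
        * (min(top, bottom) / max(top, bottom))
    )
```

A perfectly centered location (equal distances to all four sides) scores 1.0, while strongly off-center locations approach 0, which is why this value works as a binary-cross-entropy target.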