
Commit

add 'once_differentiable' for dcn and modify 'configs/cityscapes/README.md' (facebookresearch#701)

* make pixel indexes 0-based for bounding box in pascal voc dataset

* replacing all instances of torch.distributed.deprecated with torch.distributed

* replacing all instances of torch.distributed.deprecated with torch.distributed

* add GroupNorm

* add GroupNorm -- sort out yaml files

* use torch.nn.GroupNorm instead, replace 'use_gn' with 'conv_block' and use 'BaseStem'&'Bottleneck' to simplify the code

* modification on 'group_norm' and 'conv_with_kaiming_uniform' function

* modification on yaml files in configs/gn_baselines/ and reduce the amount of indentation and code duplication

* use 'kaiming_uniform' to initialize resnet, disable gn after fc layer, and add dilation into ResNetHead

* agnostic-regression for bbox

* please set 'STRIDE_IN_1X1' to 'False' when the backbone uses GN (see the sketch after this list)

* add README.md for GN

* add dcn from mmdetection

* add documentation for finetuning cityscapes

* add documentation for finetuning cityscapes

* add documentation for finetuning cityscapes

* add 'once_differentiable' for dcn and modify 'configs/cityscapes/README.md'
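
A minimal sketch of one way to apply the STRIDE_IN_1X1 note above, assuming maskrcnn_benchmark's yacs-based config; the file name under configs/gn_baselines/ and the key MODEL.RESNETS.STRIDE_IN_1X1 are assumptions to verify against your checkout:

```python
# Sketch (assumed key/file names): disable STRIDE_IN_1X1 for a GN backbone so
# the downsampling stride sits in the 3x3 conv rather than the 1x1 conv.
from maskrcnn_benchmark.config import cfg

cfg.merge_from_file("configs/gn_baselines/e2e_mask_rcnn_R_50_FPN_1x_gn.yaml")  # assumed file name
cfg.merge_from_list(["MODEL.RESNETS.STRIDE_IN_1X1", False])
cfg.freeze()
print(cfg.MODEL.RESNETS.STRIDE_IN_1X1)  # expected: False
```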
zimenglan-sysu-512 authored and fmassa committed Apr 20, 2019
1 parent e876521 commit c9210b1
Showing 3 changed files with 11 additions and 2 deletions.
8 changes: 6 additions & 2 deletions configs/cityscapes/README.md
@@ -195,10 +195,13 @@ def clip_weights_from_pretrain_of_coco_to_cityscapes(f, out_file):
print("f: {}\nout_file: {}".format(f, out_file))
torch.save(m, out_file)
```
Step 3: modify the `input&solver` configuration in the `yaml` file, like this:
Step 3: modify the `input&weight&solver` configuration in the `yaml` file, like this:
```
MODEL:
WEIGHT: "xxx.pth" # the model u save from above code
INPUT:
MIN_SIZE_TRAIN: (800, 832, 863, 896, 928, 960, 992, 1024, 1024)
MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024, 1024)
MAX_SIZE_TRAIN: 2048
MIN_SIZE_TEST: 1024
MAX_SIZE_TEST: 2048
@@ -210,4 +213,5 @@ SOLVER:
   STEPS: (3000,)
   MAX_ITER: 4000
 ```
+Step 4: train the model.
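
Since MODEL.WEIGHT above points at the checkpoint written by the trimming helper whose tail (`torch.save(m, out_file)`) appears in the hunk, a quick sanity check before Step 4 might look like the sketch below; "xxx.pth" is kept as the placeholder from the yaml, and the layout of the saved object (a plain state dict or a dict wrapping one under a "model" key) is an assumption handled both ways.

```python
# Sketch: confirm the trimmed checkpoint referenced by MODEL.WEIGHT loads.
# "xxx.pth" is the placeholder path from the yaml above; replace it with the
# out_file you passed to clip_weights_from_pretrain_of_coco_to_cityscapes.
import torch

ckpt = torch.load("xxx.pth", map_location="cpu")
state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
print("loaded", len(state), "entries, e.g.", sorted(state)[:3])
```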

3 changes: 3 additions & 0 deletions maskrcnn_benchmark/layers/dcn/deform_conv_func.py
@@ -1,5 +1,6 @@
 import torch
 from torch.autograd import Function
+from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
 
 from maskrcnn_benchmark import _C
@@ -67,6 +68,7 @@ def forward(
         return output
 
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         input, offset, weight = ctx.saved_tensors

@@ -201,6 +203,7 @@ def forward(
         return output
 
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         if not grad_output.is_cuda:
             raise NotImplementedError
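
For context on the decorator added above, here is a minimal, self-contained sketch of the same pattern; ScaleBy2 is a toy op invented for illustration and is not part of maskrcnn_benchmark. once_differentiable runs the hand-written backward with autograd disabled, which is the usual choice when backward calls into compiled kernels (as the dcn ops do), since those gradients cannot be differentiated again anyway.

```python
# Toy illustration (not from the repo): a custom autograd Function whose
# backward is wrapped with @once_differentiable, mirroring the change above.
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class ScaleBy2(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @once_differentiable  # backward runs under no_grad; its result is treated as final
    def backward(ctx, grad_output):
        return grad_output * 2


x = torch.randn(3, requires_grad=True)
y = ScaleBy2.apply(x).sum()
(grad_x,) = torch.autograd.grad(y, x)
print(grad_x)  # first-order gradients still work: a tensor of 2s
```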
2 changes: 2 additions & 0 deletions maskrcnn_benchmark/layers/dcn/deform_pool_func.py
@@ -1,5 +1,6 @@
 import torch
 from torch.autograd import Function
+from torch.autograd.function import once_differentiable
 
 from maskrcnn_benchmark import _C

@@ -60,6 +61,7 @@ def forward(
         return output
 
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         if not grad_output.is_cuda:
             raise NotImplementedError
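
A second toy sketch (again invented for illustration) of the practical consequence: first-order gradients are unchanged, while an attempted double backward through the decorated op raises a RuntimeError because its gradient was computed outside the autograd graph. An undecorated backward written in plain tensor ops would have supported double backward; the decorator states explicitly that the dcn backward kernels do not.

```python
# Toy illustration (not from the repo): with @once_differentiable, second-order
# gradients through the custom op are refused rather than silently tracked.
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.randn(4, requires_grad=True)
(g,) = torch.autograd.grad(Square.apply(x).sum(), x, create_graph=True)
try:
    g.sum().backward()  # second-order pass
except RuntimeError as err:
    print("double backward refused:", err)
```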
