diff --git a/mmcv/cnn/bricks/wrappers.py b/mmcv/cnn/bricks/wrappers.py
index a464f86dc1..6e125b41ca 100644
--- a/mmcv/cnn/bricks/wrappers.py
+++ b/mmcv/cnn/bricks/wrappers.py
@@ -128,8 +128,8 @@ def forward(self, x):
 class MaxPool2d(nn.MaxPool2d):
 
     def forward(self, x):
-        # PyTorch 1.7 does not support empty tensor inference yet
-        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)):
+        # PyTorch 1.9 does not support empty tensor inference yet
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
             out_shape = list(x.shape[:2])
             for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
                                      _pair(self.padding), _pair(self.stride),
@@ -146,8 +146,8 @@ def forward(self, x):
 class MaxPool3d(nn.MaxPool3d):
 
     def forward(self, x):
-        # PyTorch 1.7 does not support empty tensor inference yet
-        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)):
+        # PyTorch 1.9 does not support empty tensor inference yet
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
             out_shape = list(x.shape[:2])
             for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
                                      _triple(self.padding),
diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py
index 326cfd2d2a..ffc933fec2 100644
--- a/tests/test_cnn/test_wrappers.py
+++ b/tests/test_cnn/test_wrappers.py
@@ -330,7 +330,7 @@ def test_linear(in_w, in_h, in_feature, out_feature):
         wrapper(x_empty)
 
 
-@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 8))
+@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10))
 def test_nn_op_forward_called():
 
     for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
@@ -347,6 +347,20 @@ def test_nn_op_forward_called():
             wrapper(x_normal)
             nn_module_forward.assert_called_with(x_normal)
 
+    for m in ['Conv3d', 'ConvTranspose3d', 'MaxPool3d']:
+        with patch(f'torch.nn.{m}.forward') as nn_module_forward:
+            # randn input
+            x_empty = torch.randn(0, 3, 10, 10, 10)
+            wrapper = eval(m)(3, 2, 1)
+            wrapper(x_empty)
+            nn_module_forward.assert_called_with(x_empty)
+
+            # non-randn input
+            x_normal = torch.randn(1, 3, 10, 10, 10)
+            wrapper = eval(m)(3, 2, 1)
+            wrapper(x_normal)
+            nn_module_forward.assert_called_with(x_normal)
+
     with patch('torch.nn.Linear.forward') as nn_module_forward:
         # randn input
         x_empty = torch.randn(0, 3)