Skip to content

Commit

Permalink
make voxelnet compatible with spconv 1.x and 2.x
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweiy committed Dec 19, 2021
1 parent 0176a49 commit 3fd0b87
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions det3d/models/backbones/scn.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
import numpy as np
import spconv
from spconv import SparseConv3d, SubMConv3d
try:
import spconv.pytorch as spconv
from spconv.pytorch import ops
from spconv.pytorch import SparseConv3d, SubMConv3d
except:
import spconv
from spconv import ops
from spconv import SparseConv3d, SubMConv3d

from torch import nn
from torch.nn import functional as F

from ..registry import BACKBONES
from ..utils import build_norm_layer

def replace_feature(out, new_features):
if "replace_feature" in out.__dir__():
# spconv 2.x behaviour
return out.replace_feature(new_features)
else:
out.features = new_features
return out

def conv3x3(in_planes, out_planes, stride=1, indice_key=None, bias=True):
"""3x3 convolution with padding"""
Expand Down Expand Up @@ -65,17 +79,17 @@ def forward(self, x):
identity = x

out = self.conv1(x)
out.features = self.bn1(out.features)
out.features = self.relu(out.features)
out = replace_feature(out, self.bn1(out.features))
out = replace_feature(out, self.relu(out.features))

out = self.conv2(out)
out.features = self.bn2(out.features)
out = replace_feature(out, self.bn2(out.features))

if self.downsample is not None:
identity = self.downsample(x)

out.features += identity.features
out.features = self.relu(out.features)
out = replace_feature(out, out.features + identity.features)
out = replace_feature(out, self.relu(out.features))

return out

Expand Down

0 comments on commit 3fd0b87

Please sign in to comment.