forked from cvg/Hierarchical-Localization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sift.py
68 lines (53 loc) · 1.93 KB
/
sift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import copy
import numpy as np
import torch
from ..utils.base_model import BaseModel
import pycolmap
EPS = 1e-6  # Guards the normalizing divisions and the clip before the sqrt.


def sift_to_rootsift(x):
    """Convert SIFT descriptors to RootSIFT.

    Each descriptor (along the last axis of ``x``) is L1-normalized, clipped
    from below at ``EPS`` so the square root is well defined, square-rooted,
    and finally L2-normalized. ``EPS`` is also added to both norms so that
    all-zero descriptors do not cause a division by zero.

    Args:
        x: Array of descriptors; normalization runs along the last axis.

    Returns:
        Array of the same shape as ``x`` holding the RootSIFT descriptors.
    """
    l1_norm = np.linalg.norm(x, ord=1, axis=-1, keepdims=True)
    rooted = np.sqrt(np.clip(x / (l1_norm + EPS), EPS, None))
    l2_norm = np.linalg.norm(rooted, axis=-1, keepdims=True)
    return rooted / (l2_norm + EPS)
class SIFT(BaseModel):
    """SIFT keypoint detector and descriptor backed by ``pycolmap.extract_sift``.

    Post-processing options handled here:
      * ``root``: convert descriptors to RootSIFT via ``sift_to_rootsift``.
      * ``max_keypoints``: keep only the highest-scoring detections;
        ``-1`` keeps all of them.
    Every other configuration entry (apart from ``name``) is forwarded
    verbatim as a keyword argument to ``pycolmap.extract_sift``.
    """

    default_conf = {
        'num_octaves': 4,
        'octave_resolution': 3,
        'first_octave': 0,
        'edge_thresh': 10,
        'peak_thresh': 0.01,
        'upright': False,
        'root': True,
        'max_keypoints': -1
    }
    required_inputs = ['image']

    def _init(self, conf):
        """Store the post-processing options and bind the pycolmap extractor.

        Args:
            conf: Configuration dict, presumably ``default_conf`` merged with
                user overrides by ``BaseModel`` (TODO confirm). The keys
                ``name``, ``root`` and ``max_keypoints`` are consumed here;
                all remaining keys are passed to ``pycolmap.extract_sift``.
        """
        self.root = conf['root']
        self.max_keypoints = conf['max_keypoints']
        # Remove the keys that are not extractor options. Work on a copy so
        # the caller's conf dict is left untouched.
        vlfeat_conf = copy.deepcopy(conf)
        vlfeat_conf.pop('name', None)
        vlfeat_conf.pop('root', None)
        vlfeat_conf.pop('max_keypoints', None)
        self.extract = lambda image: pycolmap.extract_sift(
            image, **vlfeat_conf
        )

    def _forward(self, data):
        """Detect and describe SIFT features on a single grayscale image.

        Args:
            data: Dict whose ``'image'`` entry is a tensor, presumably laid out
                as ``(1, 1, H, W)`` (batch, channel, height, width; TODO
                confirm against callers), with values in ``[0, 1]``.

        Returns:
            Dict with ``'keypoints'`` of shape ``(1, N, 2)`` holding ``(x, y)``,
            ``'scores'`` of shape ``(1, N)`` and ``'descriptors'`` of shape
            ``(1, D, N)``.

        Raises:
            AssertionError: If the batch holds more than one image, the image
                is not single-channel, or its values lie outside ``[0, 1]``.
        """
        image = data['image'].cpu().numpy()
        # Only image[0, 0] is passed to the extractor below. Without this
        # check, every image after the first in a batch would be silently
        # dropped.
        assert image.shape[0] == 1, 'SIFT only supports a batch size of 1.'
        assert image.shape[1] == 1, 'SIFT expects a single-channel image.'
        assert image.min() >= -EPS and image.max() <= 1 + EPS, \
            'SIFT expects image values in [0, 1].'

        keypoints, scores, descriptors = self.extract(image[0, 0])
        keypoints = keypoints[:, :2]  # Keep only x, y.

        if self.root:
            descriptors = sift_to_rootsift(descriptors)

        if self.max_keypoints != -1:
            # TODO: check that the scores from PyCOLMAP are 100% correct,
            # follow https://github.com/mihaidusmanu/pycolmap/issues/8
            # Sort by descending score, then keep the top max_keypoints.
            indices = np.argsort(scores)[::-1][:self.max_keypoints]
            keypoints = keypoints[indices, :]
            scores = scores[indices]
            descriptors = descriptors[indices, :]

        # Add a leading batch dimension; descriptors are transposed so the
        # descriptor dimension comes before the keypoint dimension.
        return {
            'keypoints': torch.from_numpy(keypoints)[None],
            'scores': torch.from_numpy(scores)[None],
            'descriptors': torch.from_numpy(descriptors.T)[None],
        }