Skip to content

Commit

Permalink
Adds DoG-HardNet model (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
ducha-aiki authored Jan 23, 2024
1 parent b1cd942 commit be49528
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ We provide a [demo notebook](demo.ipynb) which shows how to perform feature extr
Here is a minimal script to match two images:

```python
from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED
from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
from lightglue.utils import load_image, rbd

# SuperPoint+LightGlue
Expand Down
1 change: 1 addition & 0 deletions lightglue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .aliked import ALIKED # noqa
from .disk import DISK # noqa
from .dog_hardnet import DoGHardNet # noqa
from .lightglue import LightGlue # noqa
from .sift import SIFT # noqa
from .superpoint import SuperPoint # noqa
Expand Down
41 changes: 41 additions & 0 deletions lightglue/dog_hardnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
from kornia.color import rgb_to_grayscale
from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori

from .sift import SIFT


class DoGHardNet(SIFT):
required_data_keys = ["image"]

def __init__(self, **conf):
super().__init__(**conf)
self.laf_desc = LAFDescriptor(HardNet(True)).eval()

def forward(self, data: dict) -> dict:
image = data["image"]
if image.shape[1] == 3:
image = rgb_to_grayscale(image)
device = image.device
self.laf_desc = self.laf_desc.to(device)
self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
pred = []
if "image_size" in data.keys():
im_size = data.get("image_size").long()
else:
im_size = None
for k in range(len(image)):
img = image[k]
if im_size is not None:
w, h = data["image_size"][k]
img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
p = self.extract_single_image(img)
lafs = laf_from_center_scale_ori(
p["keypoints"].reshape(1, -1, 2),
6.0 * p["scales"].reshape(1, -1, 1, 1),
torch.rad2deg(p["oris"]).reshape(1, -1, 1),
).to(device)
p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
pred.append(p)
pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
return pred
5 changes: 5 additions & 0 deletions lightglue/lightglue.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ class LightGlue(nn.Module):
"input_dim": 128,
"add_scale_ori": True,
},
"doghardnet": {
"weights": "doghardnet_lightglue",
"input_dim": 128,
"add_scale_ori": True,
},
}

def __init__(self, features="superpoint", **conf) -> None:
Expand Down

0 comments on commit be49528

Please sign in to comment.