Skip to content

Commit

Permalink
Merge branch 'main' into sarlinpe/pin-black
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe authored Feb 16, 2024
2 parents 5ffd84a + 0356125 commit 76613fb
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf).

We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629), [DISK](https://arxiv.org/abs/2006.13566), [ALIKED](https://arxiv.org/abs/2304.03608) and [SIFT](https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf) local features.
The training end evaluation code can be found in our training library [glue-factory](https://github.com/cvg/glue-factory/).
The training and evaluation code can be found in our library [glue-factory](https://github.com/cvg/glue-factory/).

## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb)

Expand All @@ -43,7 +43,7 @@ We provide a [demo notebook](demo.ipynb) which shows how to perform feature extr
Here is a minimal script to match two images:

```python
from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED
from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
from lightglue.utils import load_image, rbd

# SuperPoint+LightGlue
Expand Down
1 change: 1 addition & 0 deletions lightglue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .aliked import ALIKED # noqa
from .disk import DISK # noqa
from .dog_hardnet import DoGHardNet # noqa
from .lightglue import LightGlue # noqa
from .sift import SIFT # noqa
from .superpoint import SuperPoint # noqa
Expand Down
41 changes: 41 additions & 0 deletions lightglue/dog_hardnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
from kornia.color import rgb_to_grayscale
from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori

from .sift import SIFT


class DoGHardNet(SIFT):
required_data_keys = ["image"]

def __init__(self, **conf):
super().__init__(**conf)
self.laf_desc = LAFDescriptor(HardNet(True)).eval()

def forward(self, data: dict) -> dict:
image = data["image"]
if image.shape[1] == 3:
image = rgb_to_grayscale(image)
device = image.device
self.laf_desc = self.laf_desc.to(device)
self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
pred = []
if "image_size" in data.keys():
im_size = data.get("image_size").long()
else:
im_size = None
for k in range(len(image)):
img = image[k]
if im_size is not None:
w, h = data["image_size"][k]
img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
p = self.extract_single_image(img)
lafs = laf_from_center_scale_ori(
p["keypoints"].reshape(1, -1, 2),
6.0 * p["scales"].reshape(1, -1, 1, 1),
torch.rad2deg(p["oris"]).reshape(1, -1, 1),
).to(device)
p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
pred.append(p)
pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
return pred
39 changes: 34 additions & 5 deletions lightglue/lightglue.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(self, allow_flash: bool) -> None:
torch.backends.cuda.enable_flash_sdp(allow_flash)

def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if q.shape[-2] == 0 or k.shape[-2] == 0:
return q.new_zeros((*q.shape[:-1], v.shape[-1]))
if self.enable_flash and q.device.type == "cuda":
# use torch 2.0 scaled_dot_product_attention with flash
if self.has_sdp:
Expand Down Expand Up @@ -357,6 +359,11 @@ class LightGlue(nn.Module):
"input_dim": 128,
"add_scale_ori": True,
},
"doghardnet": {
"weights": "doghardnet_lightglue",
"input_dim": 128,
"add_scale_ori": True,
},
}

def __init__(self, features="superpoint", **conf) -> None:
Expand Down Expand Up @@ -518,6 +525,8 @@ def _forward(self, data: dict) -> dict:
prune1 = torch.ones_like(ind1)
token0, token1 = None, None
for i in range(self.conf.n_layers):
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
break
desc0, desc1 = self.transformers[i](
desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
)
Expand All @@ -526,7 +535,7 @@ def _forward(self, data: dict) -> dict:

if do_early_stop:
token0, token1 = self.token_confidence[i](desc0, desc1)
if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n):
if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
break
if do_point_pruning and desc0.shape[-2] > pruning_th:
scores0 = self.log_assignment[i].get_matchability(desc0)
Expand All @@ -545,7 +554,29 @@ def _forward(self, data: dict) -> dict:
encoding1 = encoding1.index_select(-2, keep1)
prune1[:, ind1] += 1

desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
m0 = desc0.new_full((b, m), -1, dtype=torch.long)
m1 = desc1.new_full((b, n), -1, dtype=torch.long)
mscores0 = desc0.new_zeros((b, m))
mscores1 = desc1.new_zeros((b, n))
matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
mscores = desc0.new_empty((b, 0))
if not do_point_pruning:
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
return {
"matches0": m0,
"matches1": m1,
"matching_scores0": mscores0,
"matching_scores1": mscores1,
"stop": i + 1,
"matches": matches,
"scores": mscores,
"prune0": prune0,
"prune1": prune1,
}

desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
scores, _ = self.log_assignment[i](desc0, desc1)
m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
matches, mscores = [], []
Expand Down Expand Up @@ -574,7 +605,7 @@ def _forward(self, data: dict) -> dict:
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
prune1 = torch.ones_like(mscores1) * self.conf.n_layers

pred = {
return {
"matches0": m0,
"matches1": m1,
"matching_scores0": mscores0,
Expand All @@ -586,8 +617,6 @@ def _forward(self, data: dict) -> dict:
"prune1": prune1,
}

return pred

def confidence_threshold(self, layer_index: int) -> float:
"""scaled confidence threshold"""
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
Expand Down

0 comments on commit 76613fb

Please sign in to comment.