add xfeat+lighterglue and pairs from sequential #441

Open · wants to merge 7 commits into master
11 changes: 11 additions & 0 deletions hloc/extract_features.py
@@ -125,6 +125,17 @@
"resize_max": 1024,
},
},
"xfeat": {
"output": "feats-xfeat-n5000-r1600",
"model": {
"name": "xfeat",
"max_keypoints": 5000,
},
"preprocessing": {
"grayscale": False,
"resize_max": 1600,
},
},
# Global descriptors
"dir": {
"output": "global-feats-dir",
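
For context (not part of the diff): the new conf plugs into the usual hloc extraction entry point. A minimal sketch, with hypothetical dataset paths:

from pathlib import Path

from hloc import extract_features

# Hypothetical layout; adjust to your project.
images = Path("datasets/scene/images/")
outputs = Path("outputs/scene/")

# Select the "xfeat" conf added above and export features to an HDF5 file.
feature_conf = extract_features.confs["xfeat"]
feature_path = extract_features.main(feature_conf, images, outputs)
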
33 changes: 33 additions & 0 deletions hloc/extractors/xfeat.py
@@ -0,0 +1,33 @@
import torch

from hloc import logger

from ..utils.base_model import BaseModel


class XFeat(BaseModel):
    default_conf = {
        "keypoint_threshold": 0.005,
        "max_keypoints": -1,
    }
    required_inputs = ["image"]

    def _init(self, conf):
        # Load the sparse XFeat model from the official torch.hub entry point.
        self.net = torch.hub.load(
            "verlab/accelerated_features",
            "XFeat",
            pretrained=True,
            top_k=self.conf["max_keypoints"],
        )
        logger.info("Loaded XFeat (sparse) model.")

    def _forward(self, data):
        pred = self.net.detectAndCompute(
            data["image"], top_k=self.conf["max_keypoints"]
        )[0]
        # Add a batch dimension and store descriptors channel-first (D x N),
        # as expected by the hloc feature files.
        pred = {
            "keypoints": pred["keypoints"][None],
            "scores": pred["scores"][None],
            "descriptors": pred["descriptors"].T[None],
        }
        return pred
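
A rough sketch of the tensor layout this wrapper produces (the shapes are assumptions inferred from the code above, not verified output): detectAndCompute returns one dict per image, and the extractor adds a batch dimension and stores descriptors channel-first as hloc expects.

import torch

# Hypothetical standalone call to the same torch.hub model the extractor wraps.
net = torch.hub.load(
    "verlab/accelerated_features", "XFeat", pretrained=True, top_k=4096
)
image = torch.rand(1, 3, 480, 640)  # dummy image tensor

out = net.detectAndCompute(image, top_k=4096)[0]
pred = {
    "keypoints": out["keypoints"][None],        # assumed (1, N, 2)
    "scores": out["scores"][None],              # assumed (1, N)
    "descriptors": out["descriptors"].T[None],  # assumed (1, 64, N)
}
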
7 changes: 7 additions & 0 deletions hloc/match_features.py
@@ -42,6 +42,13 @@
"features": "aliked",
},
},
"xfeat+lighterglue": {
"output": "matches-xfeat-lighterglue",
"model": {
"name": "lighterglue",
"features": "xfeat",
},
},
"superglue": {
"output": "matches-superglue",
"model": {
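
Taken together with the extractor conf, a typical sparse pipeline would look like this; a sketch only, with hypothetical paths and exhaustive pairs chosen purely for illustration:

from pathlib import Path

from hloc import extract_features, match_features, pairs_from_exhaustive

images = Path("datasets/scene/images/")
outputs = Path("outputs/scene/")
sfm_pairs = outputs / "pairs-exhaustive.txt"

# 1. Extract XFeat keypoints and descriptors.
feature_conf = extract_features.confs["xfeat"]
features = extract_features.main(feature_conf, images, outputs)

# 2. Build image pairs (exhaustive, for illustration only).
pairs_from_exhaustive.main(sfm_pairs, features=features)

# 3. Match all pairs with the new "xfeat+lighterglue" conf.
matcher_conf = match_features.confs["xfeat+lighterglue"]
matches = match_features.main(matcher_conf, sfm_pairs, feature_conf["output"], outputs)
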
58 changes: 58 additions & 0 deletions hloc/matchers/lighterglue.py
@@ -0,0 +1,58 @@
import torch
from lightglue import LightGlue as LightGlue_

from ..utils.base_model import BaseModel


class LighterGlue(BaseModel):
    default_conf_xfeat = {
        "name": "lighterglue",  # just for interfacing
        "input_dim": 64,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 96,
        "add_scale_ori": False,
        "add_laf": False,  # for KeyNetAffNetHardNet
        "scale_coef": 1.0,  # to compensate for the SIFT scale bigger than KeyNet
        "n_layers": 6,
        "num_heads": 1,
        "flash": True,  # enable FlashAttention if available
        "mp": False,  # enable mixed precision
        "depth_confidence": -1,  # early stopping, disable with -1
        "width_confidence": 0.95,  # point pruning, disable with -1
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }
    required_inputs = [
        "image0",
        "keypoints0",
        "descriptors0",
        "image1",
        "keypoints1",
        "descriptors1",
    ]

    def _init(self, conf):
        LightGlue_.default_conf = self.default_conf_xfeat
        self.net = LightGlue_(None, **conf)
        url = "https://github.com/verlab/accelerated_features/raw/main/weights/xfeat-lighterglue.pt"  # noqa: E501
        state_dict = torch.hub.load_state_dict_from_url(url)

        # Rename entries from the older LightGlue checkpoint layout.
        for i in range(self.net.conf.n_layers):
            pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
            pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
        state_dict = {k.replace("matcher.", ""): v for k, v in state_dict.items()}

        self.net.load_state_dict(state_dict, strict=False)

    def _forward(self, data):
        # hloc stores descriptors as (B, D, N); LightGlue expects (B, N, D).
        data["descriptors0"] = data["descriptors0"].transpose(-1, -2)
        data["descriptors1"] = data["descriptors1"].transpose(-1, -2)

        return self.net(
            {
                "image0": {k[:-1]: v for k, v in data.items() if k[-1] == "0"},
                "image1": {k[:-1]: v for k, v in data.items() if k[-1] == "1"},
            }
        )
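
The remapping in _init adapts the published checkpoint, which uses an older LightGlue key layout, to the module names of current lightglue releases. A toy illustration of the effect on a single key (the key name itself is hypothetical):

# How one checkpoint entry is rewritten by the renaming loop above.
old_key = "matcher.self_attn.0.Wqkv.weight"  # hypothetical example key

key = old_key
for i in range(6):  # n_layers
    key = key.replace(f"self_attn.{i}", f"transformers.{i}.self_attn")
    key = key.replace(f"cross_attn.{i}", f"transformers.{i}.cross_attn")
key = key.replace("matcher.", "")

print(key)  # transformers.0.self_attn.Wqkv.weight
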
12 changes: 10 additions & 2 deletions hloc/pairs_from_retrieval.py
@@ -81,6 +81,7 @@ def main(
    db_list=None,
    db_model=None,
    db_descriptors=None,
    match_mask=None,
):
    logger.info("Extracting image pairs from a retrieval database.")

@@ -108,8 +109,15 @@
    query_desc = get_descriptors(query_names, descriptors)
    sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device))

    if match_mask is None:
        # Avoid self-matching
        self = np.array(query_names)[:, None] == np.array(db_names)[None]
    else:
        assert match_mask.shape == (
            len(query_names),
            len(db_names),
        ), "mask shape must match size of query and database images!"
        self = match_mask
    pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0)
    pairs = [(query_names[i], db_names[j]) for i, j in pairs]

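
With the new argument, a caller can mask out pairs it already covers elsewhere (this is what pairs_from_sequential does for loop closure below). A minimal sketch, assuming global descriptors were already extracted to a hypothetical HDF5 file:

from pathlib import Path

import numpy as np

from hloc import pairs_from_retrieval

global_feats = Path("outputs/scene/global-feats-netvlad.h5")  # hypothetical
loop_pairs = Path("outputs/scene/pairs-loop.txt")

db_names = [f"frame_{i:04d}.jpg" for i in range(100)]
query_names = db_names[::2]

# True entries are excluded from retrieval, e.g. temporally adjacent frames.
match_mask = np.zeros((len(query_names), len(db_names)), dtype=bool)
for qi, name in enumerate(query_names):
    di = db_names.index(name)
    match_mask[qi, max(0, di - 5) : di + 6] = True

pairs_from_retrieval.main(
    global_feats,
    loop_pairs,
    num_matched=5,
    query_list=query_names,
    db_list=db_names,
    match_mask=match_mask,
)
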
150 changes: 150 additions & 0 deletions hloc/pairs_from_sequential.py
@@ -0,0 +1,150 @@
import argparse
import collections.abc as collections
import os
from pathlib import Path
from typing import List, Optional, Union

import numpy as np

from hloc import logger, pairs_from_retrieval
from hloc.utils.io import list_h5_names
from hloc.utils.parsers import parse_image_lists, parse_retrieval


def main(
    output: Path,
    image_list: Optional[Union[Path, List[str]]] = None,
    features: Optional[Path] = None,
    window_size: Optional[int] = 10,
    quadratic_overlap: bool = True,
    use_loop_closure: bool = False,
    retrieval_path: Optional[Union[Path, str]] = None,
    retrieval_interval: Optional[int] = 2,
    num_loc: Optional[int] = 5,
) -> None:
    """
    Generate pairs of images based on sequential matching and optional loop closure.

    Args:
        output (Path): The output file path where the pairs will be saved.
        image_list (Optional[Union[Path, List[str]]]):
            A path to a file containing a list of images or a list of image names.
        features (Optional[Path]):
            A path to a feature file containing image features.
        window_size (Optional[int]):
            The size of the window for sequential matching. Default is 10.
        quadratic_overlap (bool):
            Whether to use quadratic overlap in sequential matching. Default is True.
        use_loop_closure (bool):
            Whether to use loop closure for additional matching. Default is False.
        retrieval_path (Optional[Union[Path, str]]):
            The path to the global-descriptor file used for loop-closure retrieval.
        retrieval_interval (Optional[int]):
            The interval for selecting query images for loop closure. Default is 2.
        num_loc (Optional[int]):
            The number of top retrieval matches to consider for loop closure.
            Default is 5.

    Raises:
        ValueError: If neither image_list nor features is provided,
            or if image_list is of an unknown type.

    Returns:
        None
    """
    if image_list is not None:
        if isinstance(image_list, (str, Path)):
            names_q = parse_image_lists(image_list)
        elif isinstance(image_list, collections.Iterable):
            names_q = list(image_list)
        else:
            raise ValueError(f"Unknown type for image list: {image_list}")
    elif features is not None:
        names_q = list_h5_names(features)
    else:
        raise ValueError("Provide either a list of images or a feature file.")

    pairs = []
    N = len(names_q)

    for i in range(N - 1):
        for j in range(i + 1, min(i + window_size + 1, N)):
            pairs.append((names_q[i], names_q[j]))

            if quadratic_overlap:
                # Additionally match images 2**(j - i) steps ahead when that jump
                # exceeds the window, to bridge larger temporal gaps.
                q = 2 ** (j - i)
                if q > window_size and i + q < N:
                    pairs.append((names_q[i], names_q[i + q]))

    if use_loop_closure:
        retrieval_pairs_tmp: Path = output.parent / "retrieval-pairs-tmp.txt"

        # The match mask marks, for each query image, the database images NOT to
        # consider during retrieval, i.e. there is no reason to retrieve pairs
        # that sequential matching already covers.
        query_list = names_q[::retrieval_interval]
        M = len(query_list)
        match_mask = np.zeros((M, N), dtype=bool)

        for i in range(M):
            for k in range(window_size + 1):
                if i * retrieval_interval - k >= 0 and i * retrieval_interval - k < N:
                    match_mask[i][i * retrieval_interval - k] = 1
                if i * retrieval_interval + k >= 0 and i * retrieval_interval + k < N:
                    match_mask[i][i * retrieval_interval + k] = 1

                if quadratic_overlap:
                    if (
                        i * retrieval_interval - 2**k >= 0
                        and i * retrieval_interval - 2**k < N
                    ):
                        match_mask[i][i * retrieval_interval - 2**k] = 1
                    if (
                        i * retrieval_interval + 2**k >= 0
                        and i * retrieval_interval + 2**k < N
                    ):
                        match_mask[i][i * retrieval_interval + 2**k] = 1

        pairs_from_retrieval.main(
            retrieval_path,
            retrieval_pairs_tmp,
            num_matched=num_loc,
            match_mask=match_mask,
            db_list=names_q,
            query_list=query_list,
        )

        retrieval = parse_retrieval(retrieval_pairs_tmp)

        for key, val in retrieval.items():
            for match in val:
                pairs.append((key, match))

        os.unlink(retrieval_pairs_tmp)

    logger.info(f"Found {len(pairs)} pairs.")
    with open(output, "w") as f:
        f.write("\n".join(" ".join([i, j]) for i, j in pairs))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
        Create a list of image pairs based on the sequential (alphabetical)
        order of the images.
        """
    )
    parser.add_argument("--output", required=True, type=Path)
    parser.add_argument("--image_list", type=Path)
    parser.add_argument("--features", type=Path)
    parser.add_argument(
        "--overlap",
        dest="window_size",
        type=int,
        default=10,
        help="Size of the sequential matching window",
    )
    parser.add_argument(
        "--quadratic_overlap",
        action="store_true",
        help="Whether to match images against their quadratic neighbors.",
    )
    args = parser.parse_args()
    main(**args.__dict__)
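
Typical usage of the new script, as a sketch with hypothetical paths; enabling loop closure additionally requires a global-descriptor file passed as retrieval_path:

from pathlib import Path

from hloc import pairs_from_sequential

outputs = Path("outputs/scene/")
features = outputs / "feats-xfeat-n5000-r1600.h5"  # produced by extract_features
seq_pairs = outputs / "pairs-sequential.txt"

pairs_from_sequential.main(
    output=seq_pairs,
    features=features,
    window_size=10,
    quadratic_overlap=True,
)

The same result can be obtained from the command line with: python -m hloc.pairs_from_sequential --output pairs-sequential.txt --features feats-xfeat-n5000-r1600.h5 --overlap 10 --quadratic_overlap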