From 308225c63ca44354f52e0f2bcafce4d57892a9d3 Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Tue, 26 Nov 2024 12:37:25 -0600 Subject: [PATCH 1/7] add xfeat --- hloc/extract_features.py | 11 +++++++++++ hloc/extractors/xfeat.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 hloc/extractors/xfeat.py diff --git a/hloc/extract_features.py b/hloc/extract_features.py index f7fd6990..309018e8 100644 --- a/hloc/extract_features.py +++ b/hloc/extract_features.py @@ -124,6 +124,17 @@ "grayscale": False, "resize_max": 1024, }, + }, + "xfeat": { + "output": "feats-xfeat-n5000-r1600", + "model": { + "name": "xfeat", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + }, }, # Global descriptors "dir": { diff --git a/hloc/extractors/xfeat.py b/hloc/extractors/xfeat.py new file mode 100644 index 00000000..342dec0e --- /dev/null +++ b/hloc/extractors/xfeat.py @@ -0,0 +1,33 @@ +import torch + +from hloc import logger + +from ..utils.base_model import BaseModel + + +class XFeat(BaseModel): + default_conf = { + "keypoint_threshold": 0.005, + "max_keypoints": -1, + } + required_inputs = ["image"] + + def _init(self, conf): + self.net = torch.hub.load( + "verlab/accelerated_features", + "XFeat", + pretrained=True, + top_k=self.conf["max_keypoints"], + ) + logger.info("Load XFeat(sparse) model done.") + + def _forward(self, data): + pred = self.net.detectAndCompute( + data["image"], top_k=self.conf["max_keypoints"] + )[0] + pred = { + "keypoints": pred["keypoints"][None], + "scores": pred["scores"][None], + "descriptors": pred["descriptors"].T[None], + } + return pred \ No newline at end of file From 75a97e9ec24ae7f50352ed7096616f941e50e185 Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Thu, 28 Nov 2024 12:08:06 -0600 Subject: [PATCH 2/7] add xfeat lighterglue --- hloc/match_features.py | 7 +++++ hloc/matchers/lighterglue.py | 58 ++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 hloc/matchers/lighterglue.py diff --git a/hloc/match_features.py b/hloc/match_features.py index 679e81e9..ad818a0c 100644 --- a/hloc/match_features.py +++ b/hloc/match_features.py @@ -42,6 +42,13 @@ "features": "aliked", }, }, + "xfeat+lighterglue": { + "output": "matches-xfeat-lighterglue", + "model": { + "name": "lighterglue", + "features": "xfeat", + }, + }, "superglue": { "output": "matches-superglue", "model": { diff --git a/hloc/matchers/lighterglue.py b/hloc/matchers/lighterglue.py new file mode 100644 index 00000000..52244f47 --- /dev/null +++ b/hloc/matchers/lighterglue.py @@ -0,0 +1,58 @@ +import torch +from lightglue import LightGlue as LightGlue_ +from ..utils.base_model import BaseModel + + +class LighterGlue(BaseModel): + default_conf_xfeat = { + "name": "lighterglue", # just for interfacing + "input_dim": 64, # input descriptor dimension (autoselected from weights) + "descriptor_dim": 96, + "add_scale_ori": False, + "add_laf": False, # for KeyNetAffNetHardNet + "scale_coef": 1.0, # to compensate for the SIFT scale bigger than KeyNet + "n_layers": 6, + "num_heads": 1, + "flash": True, # enable FlashAttention if available. + "mp": False, # enable mixed precision + "depth_confidence": -1, # early stopping, disable with -1 + "width_confidence": 0.95, # point pruning, disable with -1 + "filter_threshold": 0.1, # match threshold + "weights": None, + } + required_inputs = [ + "image0", + "keypoints0", + "descriptors0", + "image1", + "keypoints1", + "descriptors1", + ] + + def _init(self, conf): + LightGlue_.default_conf = self.default_conf_xfeat + self.net = LightGlue_(None, **conf) + state_dict = torch.hub.load_state_dict_from_url( + "https://github.com/verlab/accelerated_features/raw/main/weights/xfeat-lighterglue.pt" + ) + + # rename old state dict entries + for i in range(self.net.conf.n_layers): + pattern = f"self_attn.{i}", f"transformers.{i}.self_attn" + state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn" + state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + state_dict = {k.replace("matcher.", ""): v for k, v in state_dict.items()} + + self.net.load_state_dict(state_dict, strict=False) + + def _forward(self, data): + data["descriptors0"] = data["descriptors0"].transpose(-1, -2) + data["descriptors1"] = data["descriptors1"].transpose(-1, -2) + + return self.net( + { + "image0": {k[:-1]: v for k, v in data.items() if k[-1] == "0"}, + "image1": {k[:-1]: v for k, v in data.items() if k[-1] == "1"}, + } + ) From 2313e4228c8cfc45845f3f54543546c8d4eb608a Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Sun, 1 Dec 2024 19:03:35 -0600 Subject: [PATCH 3/7] Add match_mask parameter to pairs from retrieval for loop closure --- hloc/pairs_from_retrieval.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/hloc/pairs_from_retrieval.py b/hloc/pairs_from_retrieval.py index 32336801..4ccfbc10 100644 --- a/hloc/pairs_from_retrieval.py +++ b/hloc/pairs_from_retrieval.py @@ -81,6 +81,7 @@ def main( db_list=None, db_model=None, db_descriptors=None, + match_mask=None, ): logger.info("Extracting image pairs from a retrieval database.") @@ -108,8 +109,15 @@ def main( query_desc = get_descriptors(query_names, descriptors) sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device)) - # Avoid self-matching - self = np.array(query_names)[:, None] == np.array(db_names)[None] + if match_mask is None: + # Avoid self-matching + self = np.array(query_names)[:, None] == np.array(db_names)[None] + else: + assert match_mask.shape == ( + len(query_names), + len(db_names), + ), "mask shape must match size of query and database images!" + self = match_mask pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0) pairs = [(query_names[i], db_names[j]) for i, j in pairs] From 5ee2af2f20c159b23cfc1b5be05a5d6059808244 Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Wed, 4 Dec 2024 10:05:40 -0600 Subject: [PATCH 4/7] add pairs from sequential with quadratic overlap + loop closure --- hloc/pairs_from_sequential.py | 137 ++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 hloc/pairs_from_sequential.py diff --git a/hloc/pairs_from_sequential.py b/hloc/pairs_from_sequential.py new file mode 100644 index 00000000..687ffbf8 --- /dev/null +++ b/hloc/pairs_from_sequential.py @@ -0,0 +1,137 @@ +import os +import argparse +import collections.abc as collections +import numpy as np +from pathlib import Path +from typing import Optional, Union, List + +from hloc import logger +from hloc import pairs_from_retrieval +from hloc.utils.parsers import parse_image_lists, parse_retrieval +from hloc.utils.io import list_h5_names + + +def main( + output: Path, + image_list: Optional[Union[Path, List[str]]] = None, + features: Optional[Path] = None, + window_size: Optional[int] = 10, + quadratic_overlap: bool = True, + use_loop_closure: bool = False, + retrieval_path: Optional[Union[Path, str]] = None, + retrieval_interval: Optional[int] = 2, + num_loc: Optional[int] = 5, +) -> None: + """ + Generate pairs of images based on sequential matching and optional loop closure. + + Args: + output (Path): The output file path where the pairs will be saved. + image_list (Optional[Union[Path, List[str]]]): A path to a file containing a list of images or a list of image names. + features (Optional[Path]): A path to a feature file containing image features. + window_size (Optional[int]): The size of the window for sequential matching. Default is 10. + quadratic_overlap (bool): Whether to use quadratic overlap in sequential matching. Default is True. + use_loop_closure (bool): Whether to use loop closure for additional matching. Default is False. + retrieval_path (Optional[Union[Path, str]]): The path to the retrieval file for loop closure. + retrieval_interval (Optional[int]): The interval for selecting query images for loop closure. Default is 2. + num_loc (Optional[int]): The number of top retrieval matches to consider for loop closure. Default is 5. + + Raises: + ValueError: If neither image_list nor features are provided, or if image_list is of an unknown type. + + Returns: + None + """ + if image_list is not None: + if isinstance(image_list, (str, Path)): + print(image_list) + names_q = parse_image_lists(image_list) + elif isinstance(image_list, collections.Iterable): + names_q = list(image_list) + else: + raise ValueError(f"Unknown type for image list: {image_list}") + elif features is not None: + names_q = list_h5_names(features) + else: + raise ValueError("Provide either a list of images or a feature file.") + + pairs = [] + N = len(names_q) + + for i in range(N - 1): + for j in range(i + 1, min(i + window_size + 1, N)): + pairs.append((names_q[i], names_q[j])) + + if quadratic_overlap: + q = 2 ** (j - i) + if q > window_size and i + q < N: + pairs.append((names_q[i], names_q[i + q])) + + if use_loop_closure: + retrieval_pairs_tmp: Path = output.parent / f"retrieval-pairs-tmp.txt" + + # match mask describes for each image, which images NOT to include in retrevial match search + # I.e., no reason to get retrieval matches for matches already included from sequential matching + + query_list = names_q[::retrieval_interval] + M = len(query_list) + match_mask = np.zeros((M, N), dtype=bool) + + for i in range(M): + for k in range(window_size + 1): + if i * retrieval_interval - k >= 0 and i * retrieval_interval - k < N: + match_mask[i][i * retrieval_interval - k] = 1 + if i * retrieval_interval + k >= 0 and i * retrieval_interval + k < N: + match_mask[i][i * retrieval_interval + k] = 1 + + if quadratic_overlap: + if ( + i * retrieval_interval - 2**k >= 0 + and i * retrieval_interval - 2**k < N + ): + match_mask[i][i * retrieval_interval - 2**k] = 1 + if ( + i * retrieval_interval + 2**k >= 0 + and i * retrieval_interval + 2**k < N + ): + match_mask[i][i * retrieval_interval + 2**k] = 1 + + pairs_from_retrieval.main( + retrieval_path, + retrieval_pairs_tmp, + num_matched=num_loc, + match_mask=match_mask, + db_list=names_q, + query_list=query_list, + ) + + retrieval = parse_retrieval(retrieval_pairs_tmp) + + for key, val in retrieval.items(): + for match in val: + pairs.append((key, match)) + + os.unlink(retrieval_pairs_tmp) + + logger.info(f"Found {len(pairs)} pairs.") + with open(output, "w") as f: + f.write("\n".join(" ".join([i, j]) for i, j in pairs)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Create a list of image pairs based on the sequence of images on alphabetic order" + ) + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--image_list", type=Path) + parser.add_argument("--features", type=Path) + parser.add_argument( + "--overlap", type=int, default=10, help="Number of overlapping image pairs" + ) + parser.add_argument( + "--quadratic_overlap", + action="store_true", + help="Whether to match images against their quadratic neighbors.", + ) + args = parser.parse_args() + main(**args.__dict__) From fa2bfffc93fccf7c7ce01b9704ed203e009b06cc Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Wed, 4 Dec 2024 12:36:03 -0600 Subject: [PATCH 5/7] fix line length issues --- hloc/pairs_from_sequential.py | 37 +++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/hloc/pairs_from_sequential.py b/hloc/pairs_from_sequential.py index 687ffbf8..2db13d3d 100644 --- a/hloc/pairs_from_sequential.py +++ b/hloc/pairs_from_sequential.py @@ -27,17 +27,27 @@ def main( Args: output (Path): The output file path where the pairs will be saved. - image_list (Optional[Union[Path, List[str]]]): A path to a file containing a list of images or a list of image names. - features (Optional[Path]): A path to a feature file containing image features. - window_size (Optional[int]): The size of the window for sequential matching. Default is 10. - quadratic_overlap (bool): Whether to use quadratic overlap in sequential matching. Default is True. - use_loop_closure (bool): Whether to use loop closure for additional matching. Default is False. - retrieval_path (Optional[Union[Path, str]]): The path to the retrieval file for loop closure. - retrieval_interval (Optional[int]): The interval for selecting query images for loop closure. Default is 2. - num_loc (Optional[int]): The number of top retrieval matches to consider for loop closure. Default is 5. + image_list (Optional[Union[Path, List[str]]]): + A path to a file containing a list of images or a list of image names. + features (Optional[Path]): + A path to a feature file containing image features. + window_size (Optional[int]): + The size of the window for sequential matching. Default is 10. + quadratic_overlap (bool): + Whether to use quadratic overlap in sequential matching. Default is True. + use_loop_closure (bool): + Whether to use loop closure for additional matching. Default is False. + retrieval_path (Optional[Union[Path, str]]): + The path to the retrieval file for loop closure. + retrieval_interval (Optional[int]): + The interval for selecting query images for loop closure. Default is 2. + num_loc (Optional[int]): + The number of top retrieval matches to consider for loop closure. + Default is 5. Raises: - ValueError: If neither image_list nor features are provided, or if image_list is of an unknown type. + ValueError: If neither image_list nor features are provided, + or if image_list is of an unknown type. Returns: None @@ -70,8 +80,9 @@ def main( if use_loop_closure: retrieval_pairs_tmp: Path = output.parent / f"retrieval-pairs-tmp.txt" - # match mask describes for each image, which images NOT to include in retrevial match search - # I.e., no reason to get retrieval matches for matches already included from sequential matching + # match mask describes for each image, which images NOT to include in retrevial + # match search I.e., no reason to get retrieval matches for matches + # already included from sequential matching query_list = names_q[::retrieval_interval] M = len(query_list) @@ -120,7 +131,9 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Create a list of image pairs based on the sequence of images on alphabetic order" + description=""" + Create a list of image pairs basedon the sequence of images on alphabetic order + """ ) parser.add_argument("--output", required=True, type=Path) parser.add_argument("--image_list", type=Path) From 0224017a50039efef5fbbf92ade07a70d737525f Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Wed, 4 Dec 2024 12:41:39 -0600 Subject: [PATCH 6/7] more linter fixes --- hloc/extract_features.py | 2 +- hloc/extractors/xfeat.py | 2 +- hloc/matchers/lighterglue.py | 5 ++--- hloc/pairs_from_sequential.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/hloc/extract_features.py b/hloc/extract_features.py index 309018e8..5594d567 100644 --- a/hloc/extract_features.py +++ b/hloc/extract_features.py @@ -125,7 +125,7 @@ "resize_max": 1024, }, }, - "xfeat": { + "xfeat": { "output": "feats-xfeat-n5000-r1600", "model": { "name": "xfeat", diff --git a/hloc/extractors/xfeat.py b/hloc/extractors/xfeat.py index 342dec0e..5dc230f2 100644 --- a/hloc/extractors/xfeat.py +++ b/hloc/extractors/xfeat.py @@ -30,4 +30,4 @@ def _forward(self, data): "scores": pred["scores"][None], "descriptors": pred["descriptors"].T[None], } - return pred \ No newline at end of file + return pred diff --git a/hloc/matchers/lighterglue.py b/hloc/matchers/lighterglue.py index 52244f47..0bcef9fb 100644 --- a/hloc/matchers/lighterglue.py +++ b/hloc/matchers/lighterglue.py @@ -32,9 +32,8 @@ class LighterGlue(BaseModel): def _init(self, conf): LightGlue_.default_conf = self.default_conf_xfeat self.net = LightGlue_(None, **conf) - state_dict = torch.hub.load_state_dict_from_url( - "https://github.com/verlab/accelerated_features/raw/main/weights/xfeat-lighterglue.pt" - ) + url = "https://github.com/verlab/accelerated_features/raw/main/weights/xfeat-lighterglue.pt" # noqa: E501 + state_dict = torch.hub.load_state_dict_from_url(url) # rename old state dict entries for i in range(self.net.conf.n_layers): diff --git a/hloc/pairs_from_sequential.py b/hloc/pairs_from_sequential.py index 2db13d3d..2f240827 100644 --- a/hloc/pairs_from_sequential.py +++ b/hloc/pairs_from_sequential.py @@ -78,7 +78,7 @@ def main( pairs.append((names_q[i], names_q[i + q])) if use_loop_closure: - retrieval_pairs_tmp: Path = output.parent / f"retrieval-pairs-tmp.txt" + retrieval_pairs_tmp: Path = output.parent / "retrieval-pairs-tmp.txt" # match mask describes for each image, which images NOT to include in retrevial # match search I.e., no reason to get retrieval matches for matches From 56ac6a132d5d4983ffb12049652d6d1f2cee8553 Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Wed, 4 Dec 2024 12:44:06 -0600 Subject: [PATCH 7/7] more linter fixes --- hloc/matchers/lighterglue.py | 1 + hloc/pairs_from_sequential.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/hloc/matchers/lighterglue.py b/hloc/matchers/lighterglue.py index 0bcef9fb..c5752d92 100644 --- a/hloc/matchers/lighterglue.py +++ b/hloc/matchers/lighterglue.py @@ -1,5 +1,6 @@ import torch from lightglue import LightGlue as LightGlue_ + from ..utils.base_model import BaseModel diff --git a/hloc/pairs_from_sequential.py b/hloc/pairs_from_sequential.py index 2f240827..2bb9811d 100644 --- a/hloc/pairs_from_sequential.py +++ b/hloc/pairs_from_sequential.py @@ -1,14 +1,14 @@ -import os import argparse import collections.abc as collections -import numpy as np +import os from pathlib import Path -from typing import Optional, Union, List +from typing import List, Optional, Union -from hloc import logger -from hloc import pairs_from_retrieval -from hloc.utils.parsers import parse_image_lists, parse_retrieval +import numpy as np + +from hloc import logger, pairs_from_retrieval from hloc.utils.io import list_h5_names +from hloc.utils.parsers import parse_image_lists, parse_retrieval def main(