forked from cvg/limap
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add LBD, L2D2, LineTR. * remove download scripts for pretrained models. * minor. fix module name for sold2 matcher.
- Loading branch information
Showing
26 changed files
with
1,501 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
*.swp | ||
*.zip | ||
*.tar.gz | ||
*.th | ||
|
||
experiments | ||
build | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,3 +20,6 @@ | |
[submodule "third-party/pytlsd"] | ||
path = third-party/pytlsd | ||
url = [email protected]:iago-suarez/pytlsd.git | ||
[submodule "third-party/pytlbd"] | ||
path = third-party/pytlbd | ||
url = [email protected]:iago-suarez/pytlbd.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from __future__ import division, print_function | ||
import torch | ||
import torch.nn.init | ||
import torch.nn as nn | ||
|
||
|
||
class L2Norm(nn.Module): | ||
def __init__(self): | ||
super(L2Norm,self).__init__() | ||
self.eps = 1e-10 | ||
def forward(self, x): | ||
norm = torch.sqrt(torch.sum(x * x, dim = 1) + self.eps) | ||
x= x / norm.unsqueeze(-1).expand_as(x) | ||
return x | ||
|
||
|
||
class L2Net(nn.Module): | ||
def __init__(self): | ||
super(L2Net, self).__init__() | ||
self.features = nn.Sequential( | ||
nn.Conv2d(1, 32, kernel_size=3, padding=1, bias = False), | ||
nn.BatchNorm2d(32, affine=False), | ||
nn.ReLU(), | ||
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias = False), | ||
nn.BatchNorm2d(32, affine=False), | ||
nn.ReLU(), | ||
nn.Conv2d(32, 64, kernel_size=(4,3), stride=2, padding=1, bias = False),#3 | ||
nn.BatchNorm2d(64, affine=False), | ||
nn.ReLU(), | ||
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias = False), | ||
nn.BatchNorm2d(64, affine=False), | ||
nn.ReLU(), | ||
nn.Conv2d(64, 128, kernel_size=(4,3), stride=2,padding=1, bias = False),#3 | ||
nn.BatchNorm2d(128, affine=False), | ||
nn.ReLU(), | ||
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias = False), | ||
nn.BatchNorm2d(128, affine=False), | ||
nn.ReLU(), | ||
nn.Dropout(0.3), | ||
nn.Conv2d(128, 128, kernel_size=(12,8), bias = False),#8 | ||
nn.BatchNorm2d(128, affine=False), | ||
|
||
) | ||
self.features.apply(weights_init) | ||
return | ||
|
||
def input_norm(self,x): | ||
flat = x.view(x.size(0), -1) | ||
mp = torch.mean(flat, dim=1) | ||
sp = torch.std(flat, dim=1) + 1e-7 | ||
return (x - mp.detach().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.detach().unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) | ||
|
||
def forward(self, input): | ||
x_features = self.features(self.input_norm(input)) | ||
x = x_features.view(x_features.size(0), -1) | ||
|
||
return L2Norm()(x) | ||
|
||
def weights_init(m): | ||
if isinstance(m, nn.Conv2d): | ||
nn.init.orthogonal_(m.weight.data, gain=0.6) | ||
try: | ||
nn.init.constant_(m.bias.data, 0.01) | ||
|
||
except: | ||
pass | ||
return | ||
|
||
def get_net(): | ||
return L2Net() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .extractor import L2D2Extractor | ||
from .matcher import L2D2Matcher |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import os, sys | ||
import numpy as np | ||
import cv2 | ||
import torch | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
from base_detector import BaseDetector, BaseDetectorOptions | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | ||
import limap.util.io as limapio | ||
|
||
sys.path.append(os.path.dirname(__file__)) | ||
from .RAL_net_cov import get_net | ||
|
||
|
||
class L2D2Extractor(BaseDetector): | ||
def __init__(self, options = BaseDetectorOptions(), device=None): | ||
super(L2D2Extractor, self).__init__(options) | ||
self.mini_batch = 20 | ||
self.device = 'cuda' if device is None else device | ||
ckpt = os.path.join(os.path.dirname(__file__), | ||
'checkpoint_line_descriptor.th') | ||
if not os.path.isfile(ckpt): | ||
self.download_model(ckpt) | ||
self.model = torch.load(ckpt).to(self.device) | ||
self.model.eval() | ||
|
||
def download_model(self, path): | ||
import subprocess | ||
if not os.path.exists(os.path.dirname(path)): | ||
os.makedirs(os.path.dirname(path)) | ||
link = "https://github.com/hichem-abdellali/L2D2/blob/main/IN_OUT_DATA/INPUT_NETWEIGHT/checkpoint_line_descriptor.th?raw=true" | ||
cmd = ["wget", link, "-O", path] | ||
print("Downloading L2D2 model...") | ||
subprocess.run(cmd, check=True) | ||
|
||
def get_module_name(self): | ||
return "l2d2" | ||
|
||
def get_descinfo_fname(self, descinfo_folder, img_id): | ||
fname = os.path.join(descinfo_folder, "descinfo_{0}.npz".format(img_id)) | ||
return fname | ||
|
||
def save_descinfo(self, descinfo_folder, img_id, descinfo): | ||
limapio.check_makedirs(descinfo_folder) | ||
fname = self.get_descinfo_fname(descinfo_folder, img_id) | ||
limapio.save_npz(fname, descinfo) | ||
|
||
def read_descinfo(self, descinfo_folder, img_id): | ||
fname = self.get_descinfo_fname(descinfo_folder, img_id) | ||
descinfo = limapio.read_npz(fname) | ||
return descinfo | ||
|
||
def extract(self, camview, segs): | ||
img = camview.read_image(set_gray=self.set_gray) | ||
descinfo = self.compute_descinfo(img, segs) | ||
return descinfo | ||
|
||
def get_patch(self, img, line): | ||
""" Extract a 48x32 patch around a line [2, 2]. """ | ||
h, w = img.shape | ||
|
||
# Keep a consistent endpoint ordering | ||
if line[1, 1] < line[0, 1]: | ||
line = line[[1, 0]] | ||
|
||
# Get the rotation angle | ||
angle = np.arctan2(line[1, 0] - line[0, 0], line[1, 1] - line[0, 1]) | ||
|
||
# Compute the affine transform to center and rotate the line | ||
midpoint = line.mean(axis=0) | ||
T_midpoint_to_origin = np.array([[1., 0., -midpoint[0]], | ||
[0., 1., -midpoint[1]], | ||
[0., 0., 1.]]) | ||
T_rot = np.array([[np.cos(angle), -np.sin(angle), 0.], | ||
[np.sin(angle), np.cos(angle), 0.], | ||
[0., 0., 1.]]) | ||
T_origin_to_center = np.array([[1., 0., w // 2], | ||
[0., 1., h // 2], | ||
[0., 0., 1.]]) | ||
A = T_origin_to_center @ T_rot @ T_midpoint_to_origin | ||
|
||
# Translate and rotate the image | ||
patch = cv2.warpAffine(img, A[:2], (w, h)) | ||
|
||
# Crop and resize the patch | ||
length = np.linalg.norm(line[0] - line[1]) | ||
new_h = max(int(np.round(length)), 5) # use a minimum height of 5 for short segments | ||
new_w = new_h * 32 // 48 | ||
patch = patch[h // 2 - new_h // 2: h // 2 + new_h // 2, | ||
w // 2 - new_w // 2: w // 2 + new_w // 2] | ||
patch = cv2.resize(patch, (32, 48)) | ||
return patch | ||
|
||
def compute_descinfo(self, img, segs): | ||
""" A desc_info is composed of the following tuple / np arrays: | ||
- the line descriptors [N, 128] | ||
""" | ||
# Extract patches and compute a line descriptor for each patch | ||
lines = segs.reshape(-1, 2, 2) | ||
if len(lines) == 0: | ||
return {'line_descriptors': np.empty((0, 128))} | ||
|
||
patches, line_desc = [], [] | ||
for i, l in enumerate(lines): | ||
patches.append(self.get_patch(img, l)) | ||
|
||
if ((i + 1) % self.mini_batch == 0 | ||
or i == len(lines) - 1): | ||
# Extract the descriptors | ||
patches = torch.tensor(np.array(patches), dtype=torch.float, | ||
device=self.device)[:, None] / 255. | ||
patches = (patches - 0.492967568115862) / 0.272086182765434 | ||
with torch.no_grad(): | ||
line_desc.append(self.model(patches)) | ||
patches = [] | ||
line_desc = torch.cat(line_desc, dim=0) # [n_lines, 128] | ||
return {'line_descriptors': line_desc.cpu().numpy()} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import os, sys | ||
import numpy as np | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
from base_matcher import BaseMatcher, BaseMatcherOptions | ||
|
||
|
||
class L2D2Matcher(BaseMatcher): | ||
def __init__(self, extractor, options = BaseMatcherOptions()): | ||
super(L2D2Matcher, self).__init__(extractor, options) | ||
|
||
def get_module_name(self): | ||
return "l2d2" | ||
|
||
def check_compatibility(self, extractor): | ||
return extractor.get_module_name() == "l2d2" | ||
|
||
def match_pair(self, descinfo1, descinfo2): | ||
if self.topk == 0: | ||
return self.match_segs_with_descinfo(descinfo1, descinfo2) | ||
else: | ||
return self.match_segs_with_descinfo_topk(descinfo1, descinfo2, topk=self.topk) | ||
|
||
def match_segs_with_descinfo(self, descinfo1, descinfo2): | ||
desc1 = descinfo1['line_descriptors'] | ||
desc2 = descinfo2['line_descriptors'] | ||
|
||
# Default case when an image has no lines | ||
if len(desc1) == 0 or len(desc2) == 0: | ||
return np.empty((0, 2)) | ||
|
||
# Mutual nearest neighbor matching | ||
score_mat = desc1 @ desc2.T | ||
nearest1 = np.argmax(score_mat, axis=1) | ||
nearest2 = np.argmax(score_mat, axis=0) | ||
mutual = nearest2[nearest1] == np.arange(len(desc1)) | ||
nearest1[~mutual] = -1 | ||
|
||
# Transform matches to [n_matches, 2] | ||
id_list_1 = np.arange(0, len(nearest1))[mutual] | ||
id_list_2 = nearest1[mutual] | ||
matches_t = np.stack([id_list_1, id_list_2], 1) | ||
return matches_t | ||
|
||
def match_segs_with_descinfo_topk(self, descinfo1, descinfo2, topk=10): | ||
desc1 = descinfo1['line_descriptors'] | ||
desc2 = descinfo2['line_descriptors'] | ||
|
||
# Default case when an image has no lines | ||
if len(desc1) == 0 or len(desc2) == 0: | ||
return np.empty((0, 2)) | ||
|
||
# Top k nearest neighbor matching | ||
score_mat = desc1 @ desc2.T | ||
matches = np.argsort(score_mat, axis=1)[:, -topk:] | ||
matches = np.flip(matches, axis=1) | ||
|
||
# Transform matches to [n_matches, 2] | ||
n_lines = len(matches) | ||
matches_t = np.stack([np.arange(n_lines).repeat(topk), | ||
matches.flatten()], axis=1) | ||
return matches_t |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .extractor import LBDExtractor | ||
from .matcher import LBDMatcher |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import os, sys | ||
import numpy as np | ||
import cv2 | ||
|
||
import pytlsd | ||
import pytlbd | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
from base_detector import BaseDetector, BaseDetectorOptions | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | ||
import limap.util.io as limapio | ||
|
||
|
||
def process_pyramid(img, detector, n_levels=5, level_scale=np.sqrt(2), presmooth=True): | ||
octave_img = img.copy() | ||
pre_sigma2 = 0 | ||
cur_sigma2 = 1.0 | ||
pyramid = [] | ||
multiscale_segs = [] | ||
for i in range(n_levels): | ||
increase_sigma = np.sqrt(cur_sigma2 - pre_sigma2) | ||
blurred = cv2.GaussianBlur(octave_img, (5, 5), increase_sigma, borderType=cv2.BORDER_REPLICATE) | ||
pyramid.append(blurred) | ||
|
||
if presmooth: | ||
multiscale_segs.append(detector(blurred)) | ||
else: | ||
multiscale_segs.append(detector(octave_img)) | ||
|
||
# cv2.imshow(f"Mine L{i}", blurred) | ||
# down sample the current octave image to get the next octave image | ||
new_size = (int(octave_img.shape[1] / level_scale), int(octave_img.shape[0] / level_scale)) | ||
octave_img = cv2.resize(blurred, new_size, 0, 0, interpolation=cv2.INTER_NEAREST) | ||
pre_sigma2 = cur_sigma2 | ||
cur_sigma2 = cur_sigma2 * 2 | ||
|
||
return multiscale_segs, pyramid | ||
|
||
|
||
def to_multiscale_lines(lines): | ||
ms_lines = [] | ||
for l in lines.reshape(-1, 4): | ||
ll = np.append(l, [0, np.linalg.norm(l[:2] - l[2:4])]) | ||
ms_lines.append([(0, ll)] + [(i, ll / (i * np.sqrt(2))) for i in range(1, 5)]) | ||
return ms_lines | ||
|
||
|
||
class LBDExtractor(BaseDetector): | ||
def __init__(self, options = BaseDetectorOptions()): | ||
super(LBDExtractor, self).__init__(options) | ||
|
||
def get_module_name(self): | ||
return "lbd" | ||
|
||
def get_descinfo_fname(self, descinfo_folder, img_id): | ||
fname = os.path.join(descinfo_folder, "descinfo_{0}.npz".format(img_id)) | ||
return fname | ||
|
||
def save_descinfo(self, descinfo_folder, img_id, descinfo): | ||
limapio.check_makedirs(descinfo_folder) | ||
fname = self.get_descinfo_fname(descinfo_folder, img_id) | ||
limapio.save_npz(fname, descinfo) | ||
|
||
def read_descinfo(self, descinfo_folder, img_id): | ||
fname = self.get_descinfo_fname(descinfo_folder, img_id) | ||
descinfo = limapio.read_npz(fname) | ||
return descinfo | ||
|
||
def extract(self, camview, segs): | ||
img = camview.read_image(set_gray=self.set_gray) | ||
descinfo = self.compute_descinfo(img, segs) | ||
return descinfo | ||
|
||
def compute_descinfo(self, img, segs): | ||
""" A desc_info is composed of the following tuple / np arrays: | ||
- the multiscale lines [N, 5] containing tuples of (scale, scaled_line) | ||
- the line descriptors [N, dim] | ||
""" | ||
ms_lines = to_multiscale_lines(segs) | ||
_, pyramid = process_pyramid(img, pytlsd.lsd, presmooth=False) | ||
descriptors = pytlbd.lbd_multiscale_pyr(pyramid, ms_lines, 9, 7) | ||
|
||
return {'ms_lines': ms_lines, 'line_descriptors': descriptors} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import os, sys | ||
import numpy as np | ||
|
||
import pytlbd | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
from base_matcher import BaseMatcher, BaseMatcherOptions | ||
|
||
class LBDMatcher(BaseMatcher): | ||
def __init__(self, extractor, options = BaseMatcherOptions()): | ||
super(LBDMatcher, self).__init__(extractor, options) | ||
|
||
def get_module_name(self): | ||
return "lbd" | ||
|
||
def check_compatibility(self, extractor): | ||
return extractor.get_module_name() == "lbd" | ||
|
||
def match_pair(self, descinfo1, descinfo2): | ||
if self.topk == 0: | ||
return self.match_segs_with_descinfo(descinfo1, descinfo2) | ||
else: | ||
return self.match_segs_with_descinfo_topk(descinfo1, descinfo2, topk=self.topk) | ||
|
||
def match_segs_with_descinfo(self, descinfo1, descinfo2): | ||
try: | ||
matches = pytlbd.lbd_matching_multiscale( | ||
descinfo1['ms_lines'].tolist(), | ||
descinfo2['ms_lines'].tolist(), | ||
descinfo1['line_descriptors'].tolist(), | ||
descinfo2['line_descriptors'].tolist()) | ||
matches = np.array(matches)[:, :2] | ||
except RuntimeError: | ||
matches = np.zeros((0, 2)) | ||
return matches | ||
|
||
def match_segs_with_descinfo_topk(self, descinfo1, descinfo2, topk=10): | ||
raise NotImplementedError() |
Oops, something went wrong.