diff --git a/lightglue/lightglue.py b/lightglue/lightglue.py index fcbb7ee..083f926 100644 --- a/lightglue/lightglue.py +++ b/lightglue/lightglue.py @@ -104,6 +104,8 @@ def __init__(self, allow_flash: bool) -> None: torch.backends.cuda.enable_flash_sdp(allow_flash) def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if q.shape[-2] == 0 or k.shape[-2] == 0: + return q.new_zeros((*q.shape[:-1], v.shape[-1])) if self.enable_flash and q.device.type == "cuda": # use torch 2.0 scaled_dot_product_attention with flash if self.has_sdp: @@ -518,6 +520,8 @@ def _forward(self, data: dict) -> dict: prune1 = torch.ones_like(ind1) token0, token1 = None, None for i in range(self.conf.n_layers): + if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints + break desc0, desc1 = self.transformers[i]( desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1 ) @@ -526,7 +530,7 @@ def _forward(self, data: dict) -> dict: if do_early_stop: token0, token1 = self.token_confidence[i](desc0, desc1) - if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n): + if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n): break if do_point_pruning and desc0.shape[-2] > pruning_th: scores0 = self.log_assignment[i].get_matchability(desc0) @@ -545,7 +549,29 @@ def _forward(self, data: dict) -> dict: encoding1 = encoding1.index_select(-2, keep1) prune1[:, ind1] += 1 - desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] + if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints + m0 = desc0.new_full((b, m), -1, dtype=torch.long) + m1 = desc1.new_full((b, n), -1, dtype=torch.long) + mscores0 = desc0.new_zeros((b, m)) + mscores1 = desc1.new_zeros((b, n)) + matches = desc0.new_empty((b, 0, 2), dtype=torch.long) + mscores = desc0.new_empty((b, 0)) + if not do_point_pruning: + prune0 = torch.ones_like(mscores0) * self.conf.n_layers + prune1 = torch.ones_like(mscores1) * self.conf.n_layers + return { + "matches0": m0, + "matches1": m1, + "matching_scores0": mscores0, + "matching_scores1": mscores1, + "stop": i + 1, + "matches": matches, + "scores": mscores, + "prune0": prune0, + "prune1": prune1, + } + + desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding scores, _ = self.log_assignment[i](desc0, desc1) m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold) matches, mscores = [], [] @@ -574,7 +600,7 @@ def _forward(self, data: dict) -> dict: prune0 = torch.ones_like(mscores0) * self.conf.n_layers prune1 = torch.ones_like(mscores1) * self.conf.n_layers - pred = { + return { "matches0": m0, "matches1": m1, "matching_scores0": mscores0, @@ -586,8 +612,6 @@ def _forward(self, data: dict) -> dict: "prune1": prune1, } - return pred - def confidence_threshold(self, layer_index: int) -> float: """scaled confidence threshold""" threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)