diff --git a/models/matching_module.py b/models/matching_module.py index a25e8df..985f777 100644 --- a/models/matching_module.py +++ b/models/matching_module.py @@ -131,7 +131,7 @@ def on_validation_epoch_end(self): gc.collect() def configure_optimizers(self): - optimizer = torch.optim.Adam(self.superglue.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.superglue.parameters()), lr=self.config['lr']) scheduler = torch.optim.lr_scheduler.StepLR( optimizer=optimizer, step_size=1,