Add --lr flag to set initial LR for recognition training
 - Add `--lr` flag to set the initial LR, making it easier to experiment
   with different values.
 - Remove use of the deprecated `verbose` kwarg for `ReduceLROnPlateau` and
   instead use `get_last_lr` to log the learning rate.
robertknight committed Mar 3, 2024
1 parent 1693ab8 commit 3ab5fb8
Showing 1 changed file with 6 additions and 2 deletions.
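
A side note on the new flag (an illustration only, not part of the commit): `--lr` is declared without a default, so argparse leaves `args.lr` as `None` when the flag is omitted, and the `args.lr or 1e-3` expression in the diff below falls back to Adam's default learning rate.

import argparse

# Sketch of the fallback behaviour; mirrors the flag added in the diff below.
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, help="Initial learning rate")

args = parser.parse_args([])                # --lr omitted -> args.lr is None
print(args.lr or 1e-3)                      # 0.001 (Adam's default)

args = parser.parse_args(["--lr", "5e-4"])  # --lr given
print(args.lr or 1e-3)                      # 0.0005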
ocrs_models/train_rec.py (6 additions, 2 deletions)
@@ -317,6 +317,7 @@ def main():
     parser.add_argument("--batch-size", type=int, default=20)
     parser.add_argument("--checkpoint", type=str, help="Model checkpoint to load")
     parser.add_argument("--export", type=str, help="Export model to ONNX format")
+    parser.add_argument("--lr", type=float, help="Initial learning rate")
     parser.add_argument(
         "--max-epochs", type=int, help="Maximum number of epochs to train for"
     )
@@ -377,9 +378,10 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = RecognitionModel(alphabet=DEFAULT_ALPHABET).to(device)

-    optimizer = torch.optim.Adam(model.parameters())
+    initial_lr = args.lr or 1e-3  # 1e-3 is the Adam default
+    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, factor=0.1, patience=3, verbose=True
+        optimizer, factor=0.1, patience=3
     )

     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -443,6 +445,8 @@ def main():

         scheduler.step(val_loss)

+        print(f"Current learning rate {scheduler.get_last_lr()}")
+
         if enable_wandb:
             wandb.log(
                 {
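
For reference, a minimal standalone sketch of the training-loop pattern the diff moves to: the initial learning rate is passed to Adam explicitly, `ReduceLROnPlateau` is constructed without the deprecated `verbose` kwarg, and `get_last_lr` is used to log the current rate. The toy model and synthetic data are placeholders rather than anything from the repository, and the sketch assumes a PyTorch version recent enough that `ReduceLROnPlateau` exposes `get_last_lr` (which the commit itself relies on).

import torch

# Placeholders standing in for RecognitionModel and the real dataset.
model = torch.nn.Linear(4, 1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)

initial_lr = 1e-3  # in the commit this is `args.lr or 1e-3`
optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=3
)

for epoch in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Stand-in for the validation loss; a value that stops improving makes
    # the scheduler cut the learning rate by `factor` after `patience` epochs.
    val_loss = 1.0
    scheduler.step(val_loss)

    # get_last_lr() returns one entry per parameter group.
    print(f"epoch {epoch}: learning rate {scheduler.get_last_lr()}")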
