From 2cfae436adf6a2365b1969a0eade7cde2ce2c16e Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sun, 25 Feb 2024 13:03:33 +0000
Subject: [PATCH 1/3] Take into account input padding when decoding outputs
 from model

When decoding outputs from a model to preview or compute the char error
rate, the part of the output that corresponds to the padding region of
the input needs to be ignored. The loss function only takes into
consideration the un-padded part of the input.

This fixes an issue where predictions had spurious extra characters on
the end and the calculated char error rate was higher than it should
have been.
---
 ocrs_models/train_rec.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/ocrs_models/train_rec.py b/ocrs_models/train_rec.py
index 6a04ab8..0f4fca3 100644
--- a/ocrs_models/train_rec.py
+++ b/ocrs_models/train_rec.py
@@ -27,7 +27,11 @@ def __init__(self):
         self.char_errors = 0
 
     def update(
-        self, targets: torch.Tensor, target_lengths: list[int], preds: torch.Tensor
+        self,
+        targets: torch.Tensor,
+        target_lengths: list[int],
+        preds: torch.Tensor,
+        pred_lengths: list[int],
     ):
         """
         Update running statistics given targets and predictions for a batch of images.
@@ -35,7 +39,12 @@ def update(
         :param targets: [batch, seq] tensor of target character indices
         :param target_lengths: Lengths of target sequences
         :param preds: [seq, batch, class] tensor of character predictions
+        :param pred_lengths: Lengths of predicted sequences
         """
+
+        assert len(target_lengths) == targets.size(0)
+        assert len(pred_lengths) == preds.size(1)
+
         total_chars = sum(target_lengths)
         char_errors = 0
 
@@ -50,9 +59,9 @@
 
         alphabet_chars = list(DEFAULT_ALPHABET)
 
-        for y, x in zip(targets_list, preds_list):
-            target_text = decode_text(y, alphabet_chars)
-            pred_text = ctc_greedy_decode_text(x, alphabet_chars)
+        for y, x, x_len in zip(targets_list, preds_list, pred_lengths):
+            target_text = decode_text(y, alphabet_chars)
+            pred_text = ctc_greedy_decode_text(x[:x_len], alphabet_chars)
             char_errors += levenshtein(target_text, pred_text)
 
         self.total_chars += total_chars
@@ -111,15 +120,16 @@ def train(
 
             pred_seq = model(img)
             batch_loss = loss(pred_seq, text_seq, input_lengths, target_lengths)
-            stats.update(text_seq, target_lengths, pred_seq)
+            stats.update(text_seq, target_lengths, pred_seq, input_lengths)
 
             # Preview decoded text for first batch in the dataset.
             if batch_idx == 0:
                 for i in range(min(10, len(text_seq))):
                     y = text_seq[i]
                     x = pred_seq[:, i, :].argmax(-1)
+                    x_len = input_lengths[i]
                     target_text = decode_text(y, list(DEFAULT_ALPHABET))
-                    pred_text = ctc_greedy_decode_text(x, list(DEFAULT_ALPHABET))
+                    pred_text = ctc_greedy_decode_text(x[:x_len], list(DEFAULT_ALPHABET))
                     print(f'Sample train prediction "{pred_text}" target "{target_text}"')
 
             if math.isnan(batch_loss.item()):
@@ -182,15 +192,19 @@ def test(
 
             # Predict [seq, batch, class] from [batch, 1, height, width].
             pred_seq = model(img)
-            stats.update(text_seq, target_lengths, pred_seq)
+            stats.update(text_seq, target_lengths, pred_seq, input_lengths)
 
             # Preview decoded text for first batch in the dataset.
             if batch_idx == 0:
                 for i in range(min(10, len(text_seq))):
                     y = text_seq[i]
                     x = pred_seq[:, i, :].argmax(-1)
+                    x_len = input_lengths[i]
+
                     target_text = decode_text(y, list(DEFAULT_ALPHABET))
-                    pred_text = ctc_greedy_decode_text(x, list(DEFAULT_ALPHABET))
+                    pred_text = ctc_greedy_decode_text(
+                        x[:x_len], list(DEFAULT_ALPHABET)
+                    )
                     print(
                         f'Sample test prediction "{pred_text}" target "{target_text}"'
                     )
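[Note on PATCH 1/3] A minimal sketch of why the truncation matters. The
decoder below is a hypothetical stand-in for the repository's
`ctc_greedy_decode_text`, assuming the CTC blank label is class 0 and that
class i >= 1 maps to alphabet[i - 1]:

    import torch

    def greedy_ctc_decode(logits: torch.Tensor, alphabet: list[str]) -> str:
        best = logits.argmax(-1).tolist()  # best class per output step
        chars = []
        prev = 0
        for cls in best:
            if cls != 0 and cls != prev:  # drop blanks, collapse repeats
                chars.append(alphabet[cls - 1])
            prev = cls
        return "".join(chars)

    # preds: [seq, batch, class]; pred_lengths[i] is the output length for
    # the un-padded width of image i. Slicing preds[:n, i] first excludes
    # the padding region, which otherwise decodes to spurious trailing
    # characters and inflates the char error rate.
    def decode_batch(preds, pred_lengths, alphabet):
        return [
            greedy_ctc_decode(preds[:n, i], alphabet)
            for i, n in enumerate(pred_lengths)
        ]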
From 1693ab85f15906cfd2aea461d5f9dbed8f4a4f36 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sun, 25 Feb 2024 08:46:07 +0000
Subject: [PATCH 2/3] Reduce amount of padding in target vectors for
 recognition training

Since the model output downsamples the input image width by 4x, the
size increment for target vectors can be smaller than the size
increment for inputs.

Also make the image width step a power of 2, since such sizes are
generally handled more efficiently by various layers of the ML runtime.
---
 ocrs_models/train_rec.py | 40 +++++++++++++++++++++++++++++++---------
 1 file changed, 31 insertions(+), 9 deletions(-)

diff --git a/ocrs_models/train_rec.py b/ocrs_models/train_rec.py
index 0f4fca3..39bd404 100644
--- a/ocrs_models/train_rec.py
+++ b/ocrs_models/train_rec.py
@@ -256,28 +256,50 @@ def text_len(sample: dict) -> int:
     def image_width(sample: dict) -> int:
         return sample["image"].shape[-1]
 
+    # Factor by which the model's output sequence length is reduced compared to
+    # the width of the input image.
+    downsample_factor = 4
+
     # Determine width of batched tensors. We round up the value to reduce the
     # variation in tensor sizes across batches. Having too many distinct tensor
     # sizes has been observed to lead to memory fragmentation and ultimately
     # memory exhaustion when training on GPUs.
-    max_img_len = round_up(max([image_width(s) for s in samples]), 250)
-    max_text_len = round_up(max([text_len(s) for s in samples]), 250)
+    img_width_step = 256
+    max_img_width = max(image_width(s) for s in samples)
+    max_img_width = round_up(max_img_width, img_width_step)
+
+    max_text_len = max(text_len(s) for s in samples)
+    max_text_len = round_up(max_text_len, img_width_step // downsample_factor)
 
     # Remove samples where the target text is incompatible with the width of
     # the image after downsampling by the model's CNN, which reduces the
-    # width by 4x.
+    # width by `downsample_factor`.
     samples = [
         s
         for s in samples
-        if ctc_input_and_target_compatible(image_width(s) // 4, s["text_seq"])
+        if ctc_input_and_target_compatible(
+            image_width(s) // downsample_factor, s["text_seq"]
+        )
     ]
 
-    for s in samples:
-        s["text_len"] = text_len(s)
-        s["text_seq"] = F.pad(s["text_seq"], [0, max_text_len - s["text_len"]])
+    for sample in samples:
+        text_pad_value = 0  # CTC blank label
+        sample["text_len"] = text_len(sample)
+        sample["text_seq"] = F.pad(
+            sample["text_seq"],
+            [0, max_text_len - sample["text_len"]],
+            mode="constant",
+            value=text_pad_value,
+        )
 
-        s["image_width"] = image_width(s)
-        s["image"] = F.pad(s["image"], [0, max_img_len - s["image_width"]])
+        image_pad_value = 0.0  # Grey, since image values are in [-0.5, 0.5]
+        sample["image_width"] = image_width(sample)
+        sample["image"] = F.pad(
+            sample["image"],
+            [0, max_img_width - sample["image_width"]],
+            mode="constant",
+            value=image_pad_value,
+        )
 
     return default_collate(samples)
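[Note on PATCH 2/3] The rounding arithmetic, sketched under the assumption
that `round_up(n, step)` rounds `n` up to a multiple of `step` (matching how
the collate function uses it). Because the CNN downsamples width by 4x, the
text-length step only needs to be a quarter of the image-width step:

    def round_up(n: int, step: int) -> int:
        # Assumed semantics of the repo's round_up helper.
        return step * ((n + step - 1) // step)

    downsample_factor = 4  # model output width = input width // 4
    img_width_step = 256   # power of 2; friendlier for the ML runtime

    # Example batch: widest image 900 px, longest target 37 characters.
    max_img_width = round_up(900, img_width_step)                     # 1024
    max_text_len = round_up(37, img_width_step // downsample_factor)  # 64

    # CTC needs the output length to be at least the target length, and
    # 1024 // 4 = 256 >= 64, so the padded sizes stay compatible.
    assert max_img_width // downsample_factor >= max_text_len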
From 3ab5fb84c93a2f4ccfa03dcae61fe5067a875d9f Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sun, 25 Feb 2024 08:31:05 +0000
Subject: [PATCH 3/3] Add `--lr` flag to set initial LR for recognition
 training

- Add `--lr` flag to set the initial LR, making it easier to experiment
  with different values.
- Remove use of the deprecated `verbose` kwarg for `ReduceLROnPlateau`
  and instead use `get_last_lr` to log the learning rate.
---
 ocrs_models/train_rec.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/ocrs_models/train_rec.py b/ocrs_models/train_rec.py
index 39bd404..7280d46 100644
--- a/ocrs_models/train_rec.py
+++ b/ocrs_models/train_rec.py
@@ -317,6 +317,7 @@ def main():
     parser.add_argument("--batch-size", type=int, default=20)
     parser.add_argument("--checkpoint", type=str, help="Model checkpoint to load")
     parser.add_argument("--export", type=str, help="Export model to ONNX format")
+    parser.add_argument("--lr", type=float, help="Initial learning rate")
     parser.add_argument(
         "--max-epochs", type=int, help="Maximum number of epochs to train for"
     )
@@ -377,9 +378,10 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
     model = RecognitionModel(alphabet=DEFAULT_ALPHABET).to(device)
-    optimizer = torch.optim.Adam(model.parameters())
+    initial_lr = args.lr or 1e-3  # 1e-3 is the Adam default
+    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, factor=0.1, patience=3, verbose=True
+        optimizer, factor=0.1, patience=3
     )
 
     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -443,6 +445,8 @@ def main():
 
         scheduler.step(val_loss)
 
+        print(f"Current learning rate {scheduler.get_last_lr()}")
+
         if enable_wandb:
             wandb.log(
                 {
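[Note on PATCH 3/3] A standalone sketch of the new logging, assuming a
PyTorch release where `ReduceLROnPlateau` provides `get_last_lr` (which this
patch already requires, since it also drops the deprecated `verbose` kwarg).
`get_last_lr` returns one learning rate per parameter group:

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.1, patience=3
    )

    # With a constant (plateaued) validation loss, the LR drops by `factor`
    # once `patience` epochs pass without improvement.
    for epoch in range(6):
        scheduler.step(1.0)
        print(f"Current learning rate {scheduler.get_last_lr()}")
    # Expected: [0.001] for the first epochs, then [0.0001].

One design note on `initial_lr = args.lr or 1e-3`: when `--lr` is omitted,
argparse yields `None`, so the expression falls back to the Adam default.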