Skip to content

Commit

Permalink
Merge pull request #14 from robertknight/char-error-rate
Browse files: browse the repository at this point in the history
Fix char error rate calculation
  • Loading branch information
robertknight authored Mar 3, 2024
2 parents 4caa6fd + 3ab5fb8 commit ed2f430
Showing 1 changed file with 58 additions and 18 deletions.
76 changes: 58 additions & 18 deletions ocrs_models/train_rec.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,15 +27,24 @@ def __init__(self):
self.char_errors = 0

def update(
self, targets: torch.Tensor, target_lengths: list[int], preds: torch.Tensor
self,
targets: torch.Tensor,
target_lengths: list[int],
preds: torch.Tensor,
pred_lengths: list[int],
):
"""
Update running statistics given targets and predictions for a batch of images.
:param targets: [batch, seq] tensor of target character indices
:param target_lengths: Lengths of target sequences
:param preds: [seq, batch, class] tensor of character predictions
:param pred_lengths: Lengths of predicted sequences
"""

assert len(target_lengths) == targets.size(0)
assert len(pred_lengths) == preds.size(1)

total_chars = sum(target_lengths)
char_errors = 0

Expand All @@ -50,9 +59,9 @@ def update(

alphabet_chars = list(DEFAULT_ALPHABET)

for y, x in zip(targets_list, preds_list):
for y, x, x_len in zip(targets_list, preds_list, pred_lengths):
target_text = decode_text(y, alphabet_chars)
pred_text = ctc_greedy_decode_text(x, alphabet_chars)
pred_text = ctc_greedy_decode_text(x[:x_len], alphabet_chars)
char_errors += levenshtein(target_text, pred_text)

self.total_chars += total_chars
Expand Down Expand Up @@ -111,15 +120,16 @@ def train(
pred_seq = model(img)
batch_loss = loss(pred_seq, text_seq, input_lengths, target_lengths)

stats.update(text_seq, target_lengths, pred_seq)
stats.update(text_seq, target_lengths, pred_seq, input_lengths)

# Preview decoded text for first batch in the dataset.
if batch_idx == 0:
for i in range(min(10, len(text_seq))):
y = text_seq[i]
x = pred_seq[:, i, :].argmax(-1)
x_len = input_lengths[i]
target_text = decode_text(y, list(DEFAULT_ALPHABET))
pred_text = ctc_greedy_decode_text(x, list(DEFAULT_ALPHABET))
pred_text = ctc_greedy_decode_text(x[:x_len], list(DEFAULT_ALPHABET))
print(f'Sample train prediction "{pred_text}" target "{target_text}"')

if math.isnan(batch_loss.item()):
Expand Down Expand Up @@ -182,15 +192,19 @@ def test(
# Predict [seq, batch, class] from [batch, 1, height, width].
pred_seq = model(img)

stats.update(text_seq, target_lengths, pred_seq)
stats.update(text_seq, target_lengths, pred_seq, input_lengths)

# Preview decoded text for first batch in the dataset.
if batch_idx == 0:
for i in range(min(10, len(text_seq))):
y = text_seq[i]
x = pred_seq[:, i, :].argmax(-1)
x_len = input_lengths[i]

target_text = decode_text(y, list(DEFAULT_ALPHABET))
pred_text = ctc_greedy_decode_text(x, list(DEFAULT_ALPHABET))
pred_text = ctc_greedy_decode_text(
x[:x_len], list(DEFAULT_ALPHABET)
)
print(
f'Sample test prediction "{pred_text}" target "{target_text}"'
)
Expand Down Expand Up @@ -242,28 +256,50 @@ def text_len(sample: dict) -> int:
def image_width(sample: dict) -> int:
return sample["image"].shape[-1]

# Factor by which the model's output sequence length is reduced compared to
# the width of the input image.
downsample_factor = 4

# Determine width of batched tensors. We round up the value to reduce the
# variation in tensor sizes across batches. Having too many distinct tensor
# sizes has been observed to lead to memory fragmentation and ultimately
# memory exhaustion when training on GPUs.
max_img_len = round_up(max([image_width(s) for s in samples]), 250)
max_text_len = round_up(max([text_len(s) for s in samples]), 250)
img_width_step = 256
max_img_width = max(image_width(s) for s in samples)
max_img_width = round_up(max_img_width, img_width_step)

max_text_len = max(text_len(s) for s in samples)
max_text_len = round_up(max_text_len, img_width_step // downsample_factor)

# Remove samples where the target text is incompatible with the width of
# the image after downsampling by the model's CNN, which reduces the
# width by 4x.
# width by `downsample_factor`.
samples = [
s
for s in samples
if ctc_input_and_target_compatible(image_width(s) // 4, s["text_seq"])
if ctc_input_and_target_compatible(
image_width(s) // downsample_factor, s["text_seq"]
)
]

for s in samples:
s["text_len"] = text_len(s)
s["text_seq"] = F.pad(s["text_seq"], [0, max_text_len - s["text_len"]])
for sample in samples:
text_pad_value = 0 # CTC blank label
sample["text_len"] = text_len(sample)
sample["text_seq"] = F.pad(
sample["text_seq"],
[0, max_text_len - sample["text_len"]],
mode="constant",
value=text_pad_value,
)

s["image_width"] = image_width(s)
s["image"] = F.pad(s["image"], [0, max_img_len - s["image_width"]])
image_pad_value = 0.0 # Grey, since image values are in [-0.5, 0.5]
sample["image_width"] = image_width(sample)
sample["image"] = F.pad(
sample["image"],
[0, max_img_width - sample["image_width"]],
mode="constant",
value=image_pad_value,
)

return default_collate(samples)

Expand All @@ -281,6 +317,7 @@ def main():
parser.add_argument("--batch-size", type=int, default=20)
parser.add_argument("--checkpoint", type=str, help="Model checkpoint to load")
parser.add_argument("--export", type=str, help="Export model to ONNX format")
parser.add_argument("--lr", type=float, help="Initial learning rate")
parser.add_argument(
"--max-epochs", type=int, help="Maximum number of epochs to train for"
)
Expand Down Expand Up @@ -341,9 +378,10 @@ def main():
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = RecognitionModel(alphabet=DEFAULT_ALPHABET).to(device)

optimizer = torch.optim.Adam(model.parameters())
initial_lr = args.lr or 1e-3 # 1e-3 is the Adam default
optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, factor=0.1, patience=3, verbose=True
optimizer, factor=0.1, patience=3
)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
Expand Down Expand Up @@ -407,6 +445,8 @@ def main():

scheduler.step(val_loss)

print(f"Current learning rate {scheduler.get_last_lr()}")

if enable_wandb:
wandb.log(
{
Expand Down

0 comments on commit ed2f430

Please sign in to comment.