Fix char error rate calculation #14

Merged · 3 commits · Mar 3, 2024
ocrs_models/train_rec.py: 58 additions & 18 deletions (76 changes)
@@ -27,15 +27,24 @@ def __init__(self):
         self.char_errors = 0

     def update(
-        self, targets: torch.Tensor, target_lengths: list[int], preds: torch.Tensor
+        self,
+        targets: torch.Tensor,
+        target_lengths: list[int],
+        preds: torch.Tensor,
+        pred_lengths: list[int],
     ):
         """
         Update running statistics given targets and predictions for a batch of images.

         :param targets: [batch, seq] tensor of target character indices
         :param target_lengths: Lengths of target sequences
         :param preds: [seq, batch, class] tensor of character predictions
+        :param pred_lengths: Lengths of predicted sequences
         """
+
+        assert len(target_lengths) == targets.size(0)
+        assert len(pred_lengths) == preds.size(1)
+
         total_chars = sum(target_lengths)
         char_errors = 0

@@ -50,9 +59,9 @@

         alphabet_chars = list(DEFAULT_ALPHABET)

-        for y, x in zip(targets_list, preds_list):
+        for y, x, x_len in zip(targets_list, preds_list, pred_lengths):
             target_text = decode_text(y, alphabet_chars)
-            pred_text = ctc_greedy_decode_text(x, alphabet_chars)
+            pred_text = ctc_greedy_decode_text(x[:x_len], alphabet_chars)
             char_errors += levenshtein(target_text, pred_text)

         self.total_chars += total_chars
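
Why the new `pred_lengths` argument matters: batches are padded to a common width, so steps in `preds` beyond a sample's true input length are padding, and greedily decoding them can emit spurious characters that inflate the error count. Below is a minimal sketch of the per-sample scoring, assuming the usual CTC convention (blank label at class 0, characters at 1..N); `greedy_ctc_decode` is a hypothetical stand-in for the real `ctc_greedy_decode_text`, and this `levenshtein` is a plain reimplementation, not the project's.

```python
# Hypothetical stand-ins for ocrs_models' ctc_greedy_decode_text and
# levenshtein helpers; blank label 0 and 1-based character classes assumed.
def greedy_ctc_decode(class_ids: list[int], alphabet: list[str], blank: int = 0) -> str:
    out = []
    prev = blank
    for cls in class_ids:
        # CTC greedy decoding: collapse repeated labels, then drop blanks.
        if cls != prev and cls != blank:
            out.append(alphabet[cls - 1])  # class k maps to alphabet[k - 1]
        prev = cls
    return "".join(out)

def levenshtein(a: str, b: str) -> int:
    # Plain dynamic-programming edit distance.
    prev_row = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        row = [i]
        for j, cb in enumerate(b, 1):
            row.append(min(prev_row[j] + 1, row[j - 1] + 1, prev_row[j - 1] + (ca != cb)))
        prev_row = row
    return prev_row[-1]

alphabet = list("ab")
valid = [1, 1, 0, 2]        # steps backed by real image columns -> "ab"
padded = valid + [2, 0, 1]  # decoding into the padded tail emits junk
print(greedy_ctc_decode(padded, alphabet))  # "aba" - spurious trailing "a"
print(greedy_ctc_decode(valid, alphabet))   # "ab"
print(levenshtein("ab", "aba"))             # 1 edit the padding would add
```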
@@ -111,15 +120,16 @@ def train(
             pred_seq = model(img)
             batch_loss = loss(pred_seq, text_seq, input_lengths, target_lengths)

-            stats.update(text_seq, target_lengths, pred_seq)
+            stats.update(text_seq, target_lengths, pred_seq, input_lengths)

             # Preview decoded text for first batch in the dataset.
             if batch_idx == 0:
                 for i in range(min(10, len(text_seq))):
                     y = text_seq[i]
                     x = pred_seq[:, i, :].argmax(-1)
+                    x_len = input_lengths[i]
                     target_text = decode_text(y, list(DEFAULT_ALPHABET))
-                    pred_text = ctc_greedy_decode_text(x, list(DEFAULT_ALPHABET))
+                    pred_text = ctc_greedy_decode_text(x[:x_len], list(DEFAULT_ALPHABET))
                     print(f'Sample train prediction "{pred_text}" target "{target_text}"')

             if math.isnan(batch_loss.item()):
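
The diff consumes `input_lengths` but doesn't show where it comes from. Presumably (an assumption, not visible in this hunk) each sample's valid output length is its un-padded image width divided by the model's 4x width downsampling, so `x[:x_len]` keeps exactly the steps backed by real image columns:

```python
# Assumed derivation of per-sample model output lengths (hypothetical helper;
# the diff only shows input_lengths being consumed, not computed).
def output_lengths(image_widths: list[int], downsample_factor: int = 4) -> list[int]:
    # The recognition CNN shrinks width by `downsample_factor`, so a sample
    # occupying `w` columns of the padded batch yields w // 4 valid steps.
    return [w // downsample_factor for w in image_widths]

print(output_lengths([100, 250, 768]))  # [25, 62, 192]
```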
@@ -182,15 +192,19 @@ def test(
             # Predict [seq, batch, class] from [batch, 1, height, width].
             pred_seq = model(img)

-            stats.update(text_seq, target_lengths, pred_seq)
+            stats.update(text_seq, target_lengths, pred_seq, input_lengths)

             # Preview decoded text for first batch in the dataset.
             if batch_idx == 0:
                 for i in range(min(10, len(text_seq))):
                     y = text_seq[i]
                     x = pred_seq[:, i, :].argmax(-1)
+                    x_len = input_lengths[i]
+
                     target_text = decode_text(y, list(DEFAULT_ALPHABET))
-                    pred_text = ctc_greedy_decode_text(x, list(DEFAULT_ALPHABET))
+                    pred_text = ctc_greedy_decode_text(
+                        x[:x_len], list(DEFAULT_ALPHABET)
+                    )
                     print(
                         f'Sample test prediction "{pred_text}" target "{target_text}"'
                     )
@@ -242,28 +256,50 @@ def text_len(sample: dict) -> int:
     def image_width(sample: dict) -> int:
         return sample["image"].shape[-1]

+    # Factor by which the model's output sequence length is reduced compared to
+    # the width of the input image.
+    downsample_factor = 4
+
     # Determine width of batched tensors. We round up the value to reduce the
     # variation in tensor sizes across batches. Having too many distinct tensor
     # sizes has been observed to lead to memory fragmentation and ultimately
     # memory exhaustion when training on GPUs.
-    max_img_len = round_up(max([image_width(s) for s in samples]), 250)
-    max_text_len = round_up(max([text_len(s) for s in samples]), 250)
+    img_width_step = 256
+    max_img_width = max(image_width(s) for s in samples)
+    max_img_width = round_up(max_img_width, img_width_step)
+
+    max_text_len = max(text_len(s) for s in samples)
+    max_text_len = round_up(max_text_len, img_width_step // downsample_factor)

     # Remove samples where the target text is incompatible with the width of
     # the image after downsampling by the model's CNN, which reduces the
-    # width by 4x.
+    # width by `downsample_factor`.
     samples = [
         s
         for s in samples
-        if ctc_input_and_target_compatible(image_width(s) // 4, s["text_seq"])
+        if ctc_input_and_target_compatible(
+            image_width(s) // downsample_factor, s["text_seq"]
+        )
     ]

-    for s in samples:
-        s["text_len"] = text_len(s)
-        s["text_seq"] = F.pad(s["text_seq"], [0, max_text_len - s["text_len"]])
+    for sample in samples:
+        text_pad_value = 0  # CTC blank label
+        sample["text_len"] = text_len(sample)
+        sample["text_seq"] = F.pad(
+            sample["text_seq"],
+            [0, max_text_len - sample["text_len"]],
+            mode="constant",
+            value=text_pad_value,
+        )

-        s["image_width"] = image_width(s)
-        s["image"] = F.pad(s["image"], [0, max_img_len - s["image_width"]])
+        image_pad_value = 0.0  # Grey, since image values are in [-0.5, 0.5]
+        sample["image_width"] = image_width(sample)
+        sample["image"] = F.pad(
+            sample["image"],
+            [0, max_img_width - sample["image_width"]],
+            mode="constant",
+            value=image_pad_value,
+        )

     return default_collate(samples)
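
For reference, a hypothetical reimplementation of the `round_up` helper used above (the real one lives elsewhere in ocrs_models): bucketing widths into multiples of a step size keeps batches on a small set of tensor shapes, which is what mitigates the memory fragmentation the comment describes.

```python
def round_up(value: int, unit: int) -> int:
    # Hypothetical reimplementation: round `value` up to a multiple of `unit`.
    return ((value + unit - 1) // unit) * unit

# With img_width_step = 256, every batch width falls on a 256-pixel grid,
# so e.g. widths 700 and 750 share one tensor shape:
print(round_up(700, 256), round_up(750, 256))  # 768 768
print(round_up(768, 256))                      # 768 (already a multiple)
```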
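`ctc_input_and_target_compatible` is also not shown in this diff. The standard CTC feasibility rule it presumably encodes: the input must provide at least one step per target label, plus one extra step for the blank that separates each pair of adjacent repeated labels. A sketch under that assumption:

```python
import torch

def ctc_input_and_target_compatible(input_len: int, target: torch.Tensor) -> bool:
    # Sketch of the assumed rule: CTC cannot emit a target longer than the
    # input, counting one extra step for the blank that must split repeats.
    labels = target.tolist()
    repeats = sum(1 for a, b in zip(labels, labels[1:]) if a == b)
    return input_len >= len(labels) + repeats

# A repeated label like "aa" needs 3 steps: a, blank, a.
print(ctc_input_and_target_compatible(2, torch.tensor([1, 1])))  # False
print(ctc_input_and_target_compatible(3, torch.tensor([1, 1])))  # True
```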

@@ -281,6 +317,7 @@ def main():
     parser.add_argument("--batch-size", type=int, default=20)
     parser.add_argument("--checkpoint", type=str, help="Model checkpoint to load")
     parser.add_argument("--export", type=str, help="Export model to ONNX format")
+    parser.add_argument("--lr", type=float, help="Initial learning rate")
     parser.add_argument(
         "--max-epochs", type=int, help="Maximum number of epochs to train for"
     )
@@ -341,9 +378,10 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = RecognitionModel(alphabet=DEFAULT_ALPHABET).to(device)

-    optimizer = torch.optim.Adam(model.parameters())
+    initial_lr = args.lr or 1e-3  # 1e-3 is the Adam default
+    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, factor=0.1, patience=3, verbose=True
+        optimizer, factor=0.1, patience=3
     )

     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
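
`verbose=True` is removed from the scheduler, likely because the flag is deprecated in newer PyTorch releases; the learning rate is instead printed manually after `scheduler.step(val_loss)` in the final hunk. A toy run of these settings, assuming a PyTorch version where `ReduceLROnPlateau` exposes `get_last_lr()` (which the new print relies on):

```python
import torch

# Toy demonstration (not part of the diff): with factor=0.1 and patience=3,
# the LR is cut 10x once the monitored metric fails to improve for more
# than 3 consecutive steps.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=3
)
for step in range(6):
    scheduler.step(1.0)  # constant "val_loss": no improvement after step 0
    print(step, scheduler.get_last_lr())  # drops to [0.0001] at step 4
```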
@@ -407,6 +445,8 @@ def main():

         scheduler.step(val_loss)

+        print(f"Current learning rate {scheduler.get_last_lr()}")
+
         if enable_wandb:
             wandb.log(
                 {