Skip to content

Commit

Permalink
Some "light" linting.
Browse files Browse the repository at this point in the history
Fixes #557.
  • Loading branch information
mittagessen committed Dec 11, 2023
1 parent 4aaac7b commit 56d84b0
Show file tree
Hide file tree
Showing 21 changed files with 121 additions and 86 deletions.
9 changes: 3 additions & 6 deletions kraken/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,16 @@
import torch
import logging
import dataclasses
import numpy as np

from PIL import Image
from bidi.algorithm import get_display

from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Literal
from typing import Optional, Literal

from kraken import rpred
from kraken.containers import Segmentation, BaselineOCRRecord
from kraken.lib.codec import PytorchCodec
from kraken.lib.xml import XMLPage
from kraken.lib.models import TorchSeqRecognizer
from kraken.lib.exceptions import KrakenInputException, KrakenEncodeException
from kraken.lib.segmentation import compute_polygon_section

logger = logging.getLogger('kraken')

Expand Down Expand Up @@ -95,6 +90,8 @@ def forced_align(doc: Segmentation, model: TorchSeqRecognizer, base_dir: Optiona
at:
https://github.com/pytorch/audio/blob/main/examples/tutorials/forced_alignment_tutorial.py
"""


@dataclass
class Point:
token_index: int
Expand Down
4 changes: 2 additions & 2 deletions kraken/blla.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import torch.nn.functional as F
import torchvision.transforms as tf

from typing import Optional, Dict, Callable, Union, List, Any, Tuple, Literal
from typing import Optional, Dict, Callable, Union, List, Any, Literal

from scipy.ndimage import gaussian_filter
from skimage.filters import sobel
Expand Down Expand Up @@ -415,4 +415,4 @@ def segment(im: PIL.Image.Image,
lines=blls,
regions=regions,
script_detection=script_detection,
line_orders=[order])
line_orders=[order] if order else [])
2 changes: 1 addition & 1 deletion kraken/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class Segmentation:
script_detection: bool
lines: List[Union[BaselineLine, BBoxLine]]
regions: Dict[str, List[Region]]
line_orders: Optional[List[List[int]]] = None
line_orders: List[List[int]]

def __post_init__(self):
if not self.regions:
Expand Down
7 changes: 3 additions & 4 deletions kraken/ketos/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
from torch.utils.data import DataLoader

from kraken.serialization import render_report
from kraken.lib import models
from kraken.lib import models, util
from kraken.lib.xml import XMLPage
from kraken.lib.dataset import (global_align, compute_confusions,
PolygonGTDataset, GroundTruthDataset,
Expand All @@ -419,7 +419,6 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,

test_set = list(test_set)


if evaluation_files:
test_set.extend(evaluation_files)

Expand All @@ -445,7 +444,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
force_binarization = False
if repolygonize:
logger.warning('Repolygonization enabled in `path` mode. Will be ignored.')
test_set = [{'image': img} for img in test_set]
test_set = [{'line': util.parse_gt_path(img)} for img in test_set]
valid_norm = True

if len(test_set) == 0:
Expand Down Expand Up @@ -480,7 +479,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
ds_loader = DataLoader(ds,
batch_size=batch_size,
num_workers=workers,
pin_memory=True,
pin_memory=pin_ds_mem,
collate_fn=collate_sequences)

with KrakenProgressBar() as progress:
Expand Down
11 changes: 5 additions & 6 deletions kraken/ketos/ro.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
import logging

from PIL import Image
from typing import Dict

from kraken.lib.exceptions import KrakenInputException
from kraken.lib.default_specs import READING_ORDER_HYPER_PARAMS

from kraken.ketos.util import _validate_manifests, _expand_gt, message, to_ptl_device
Expand All @@ -36,6 +34,7 @@
# raise default max image size to 20k * 20k pixels
Image.MAX_IMAGE_PIXELS = 20000 ** 2


@click.command('rotrain')
@click.pass_context
@click.option('-B', '--batch-size', show_default=True, type=click.INT,
Expand Down Expand Up @@ -156,14 +155,13 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag,

from kraken.lib.ro import ROModel
from kraken.lib.train import KrakenTrainer
from kraken.lib.progress import KrakenProgressBar

if not (0 <= freq <= 1) and freq % 1.0 != 0:
raise click.BadOptionUsage('freq', 'freq needs to be either in the interval [0,1.0] or a positive integer.')

if pl_logger == 'tensorboard':
try:
import tensorboard
import tensorboard # NOQA
except ImportError:
raise click.BadOptionUsage('logger', 'tensorboard logger needs the `tensorboard` package installed.')

Expand Down Expand Up @@ -191,7 +189,9 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag,
'step_size': step_size,
'rop_patience': sched_patience,
'cos_t_max': cos_max,
'pl_logger': pl_logger,})
'pl_logger': pl_logger,
}
)

# disable automatic partition when given evaluation set explicitly
if evaluation_files:
Expand Down Expand Up @@ -281,7 +281,6 @@ def roadd(ctx, output, ro_model, seg_model):
"""
from kraken.lib import vgsl
from kraken.lib.ro import ROModel
from kraken.lib.train import KrakenTrainer

message(f'Adding {ro_model} reading order model to {seg_model}.')
ro_net = ROModel.load_from_checkpoint(ro_model)
Expand Down
4 changes: 3 additions & 1 deletion kraken/ketos/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

from PIL import Image

from typing import Dict

from kraken.lib.exceptions import KrakenInputException
from kraken.lib.default_specs import SEGMENTATION_HYPER_PARAMS, SEGMENTATION_SPEC

Expand Down Expand Up @@ -232,7 +234,6 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs,
from threadpoolctl import threadpool_limits

from kraken.lib.train import SegmentationModel, KrakenTrainer
from kraken.lib.progress import KrakenProgressBar

if resize != 'fail' and not load:
raise click.BadOptionUsage('resize', 'resize option requires loading an existing model')
Expand Down Expand Up @@ -431,6 +432,7 @@ def segtest(ctx, model, evaluation_files, device, workers, threads, threshold,
import torch
import torch.nn.functional as F

from kraken.lib.progress import KrakenProgressBar
from kraken.lib.train import BaselineSet, ImageInputTransforms
from kraken.lib.vgsl import TorchVGSLModel

Expand Down
4 changes: 2 additions & 2 deletions kraken/lib/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def encode(self, s: str) -> IntTensor:
idx += len(code)
encodable_suffix = True
break

if not encodable_suffix and s[idx] in self.c2l:
labels.extend(self.c2l[s[idx]])
idx += 1
encodable_suffix = True

if not encodable_suffix:
if self.strict:
raise KrakenEncodeException(f'Non-encodable sequence {s[idx:idx+5]}... encountered.')
Expand Down
8 changes: 4 additions & 4 deletions kraken/lib/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
Top-level module containing datasets for recognition and segmentation training.
"""
from .recognition import ArrowIPCRecognitionDataset, PolygonGTDataset, GroundTruthDataset # NOQA
from .segmentation import BaselineSet # NOQA
from .ro import PairWiseROSet, PageWiseROSet #NOQA
from .utils import ImageInputTransforms, collate_sequences, global_align, compute_confusions # NOQA
from .recognition import ArrowIPCRecognitionDataset, PolygonGTDataset, GroundTruthDataset # NOQA
from .segmentation import BaselineSet # NOQA
from .ro import PairWiseROSet, PageWiseROSet # NOQA
from .utils import ImageInputTransforms, collate_sequences, global_align, compute_confusions # NOQA
23 changes: 17 additions & 6 deletions kraken/lib/dataset/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
import torch
import traceback
import dataclasses
import numpy as np
import pyarrow as pa

Expand All @@ -28,7 +29,7 @@
from torchvision import transforms
from collections import Counter
from torch.utils.data import Dataset
from typing import Dict, List, Tuple, Callable, Optional, Any, Union, Literal
from typing import List, Tuple, Callable, Optional, Any, Union, Literal

from kraken.containers import BaselineLine, BBoxLine, Segmentation
from kraken.lib.util import is_bitonal
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(self):
def __call__(self, image):
return self._transforms(image=image)


class ArrowIPCRecognitionDataset(Dataset):
"""
Dataset for training a recognition model from a precompiled dataset in
Expand Down Expand Up @@ -181,7 +183,7 @@ def add(self, file: Union[str, PathLike]) -> None:
mask = np.ones(len(ds_table), dtype=bool)
for index in range(len(ds_table)):
try:
text = self._apply_text_transform(ds_table.column('lines')[index].as_py(),)
self._apply_text_transform(ds_table.column('lines')[index].as_py(),)
except KrakenInputException:
mask[index] = False
continue
Expand Down Expand Up @@ -335,7 +337,7 @@ def add(self,
self.add_line(line)
if page:
self.add_page(page)
if not (line and page):
if not (line or page):
raise ValueError('Neither line nor page data provided in dataset builder')

def add_page(self, page: Segmentation):
Expand Down Expand Up @@ -379,7 +381,7 @@ def add_line(self, line: BaselineLine):
if not line.boundary:
raise ValueError('No boundary given for line')

self._images.append((line.image, line.baseline, line.boundary))
self._images.append((line.imagename, line.baseline, line.boundary))
self._gt.append(text)
self.alphabet.update(text)

Expand Down Expand Up @@ -412,8 +414,17 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
im = item[0][0]
if not isinstance(im, Image.Image):
im = Image.open(im)
im, _ = next(extract_polygons(im, {'type': 'baselines',
'lines': [{'baseline': item[0][1], 'boundary': item[0][2]}]}))
im, _ = next(extract_polygons(im,
Segmentation(type='baselines',
imagename=item[0][0],
text_direction='horizontal-lr',
lines=[BaselineLine('id_0',
baseline=item[0][1],
boundary=item[0][2])],
script_detection=True,
regions={},
line_orders=[])
))
im = self.transforms(im)
if im.shape[0] == 3:
im_mode = 'RGB'
Expand Down
23 changes: 8 additions & 15 deletions kraken/lib/dataset/ro.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,13 @@
"""
Utility functions for data loading and training of VGSL networks.
"""
import json
import torch
import traceback
import numpy as np
import torch.nn.functional as F
import shapely.geometry as geom

from math import factorial
from os import path, PathLike
from PIL import Image
from shapely.ops import split
from itertools import groupby
from torchvision import transforms
from collections import defaultdict
from os import PathLike
from torch.utils.data import Dataset
from typing import Dict, List, Tuple, Sequence, Callable, Any, Union, Literal, Optional
from typing import Dict, Sequence, Union, Literal, Optional

from kraken.lib.xml import XMLPage

Expand Down Expand Up @@ -112,8 +103,9 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None,
torch.tensor(line_center, dtype=torch.float), # line center
torch.tensor(line_coords[0, :], dtype=torch.float), # start_point coord
torch.tensor(line_coords[-1, :], dtype=torch.float), # end point coord)
))
}
)
)
}
sorted_lines.append(line_data)
if len(sorted_lines) > 1:
self.data.append(sorted_lines)
Expand Down Expand Up @@ -212,8 +204,9 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None,
torch.tensor(line_center, dtype=torch.float), # line center
torch.tensor(line_coords[0, :], dtype=torch.float), # start_point coord
torch.tensor(line_coords[-1, :], dtype=torch.float), # end point coord)
))
}
)
)
}
sorted_lines.append(line_data)
if len(sorted_lines) > 1:
self.data.append(sorted_lines)
Expand Down
7 changes: 2 additions & 5 deletions kraken/lib/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,25 @@
"""
Utility functions for data loading and training of VGSL networks.
"""
import json
import torch
import traceback
import numpy as np
import torch.nn.functional as F
import shapely.geometry as geom

from os import path, PathLike
from PIL import Image
from shapely.ops import split
from itertools import groupby
from torchvision import transforms
from collections import defaultdict
from torch.utils.data import Dataset
from typing import Dict, List, Tuple, Sequence, Callable, Any, Union, Literal, Optional
from typing import Dict, Tuple, Sequence, Callable, Any, Union, Literal, Optional

from skimage.draw import polygon

from kraken.containers import Segmentation
from kraken.lib.xml import XMLPage

from kraken.lib.exceptions import KrakenInputException

__all__ = ['BaselineSet']

Expand Down Expand Up @@ -160,7 +157,7 @@ def add(self, doc: Union[Segmentation, XMLPage]):
self.class_mapping['regions'][reg_type] = self.num_classes - 1

self.targets.append({'baselines': baselines_, 'regions': regions_})
self.imgs.append(image)
self.imgs.append(doc.imagename)

def __getitem__(self, idx):
im = self.imgs[idx]
Expand Down
1 change: 0 additions & 1 deletion kraken/lib/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from collections import Counter
from typing import Dict, List, Tuple, Sequence, Any, Union

from kraken.lib.models import TorchSeqRecognizer
from kraken.lib.exceptions import KrakenInputException
from kraken.lib.lineest import CenterNormalizer

Expand Down
Loading

0 comments on commit 56d84b0

Please sign in to comment.