Skip to content

Commit

Permalink
Add repository tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Jan 4, 2024
1 parent a21f1e3 commit 5a14092
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 26 deletions.
3 changes: 2 additions & 1 deletion kraken/lib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
from typing import Callable, Dict, Optional, Sequence, Union, Any, Literal, TYPE_CHECKING
from pytorch_lightning.callbacks import Callback, EarlyStopping, BaseFinetuning, LearningRateMonitor

from kraken.containers import Segmentation, XMLPage
from kraken.containers import Segmentation
from kraken.lib import models, vgsl, default_specs, progress
from kraken.lib.xml import XMLPage
from kraken.lib.util import make_printable, parse_gt_path
from kraken.lib.codec import PytorchCodec
from kraken.lib.dataset import (ArrowIPCRecognitionDataset, BaselineSet,
Expand Down
24 changes: 12 additions & 12 deletions kraken/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,19 @@ def get_description(model_id: str, callback: Callable[..., Any] = lambda: None)
raise KrakenRepoException(msg)
meta_json = None
for file in record['files']:
if file['filename'] == 'metadata.json':
callback()
r = requests.get(file['links']['download'])
r.raise_for_status()
callback()
try:
meta_json = r.json()
except Exception:
msg = f'Metadata for \'{record["metadata"]["title"]}\' ({record["metadata"]["doi"]}) not in JSON format'
logger.error(msg)
raise KrakenRepoException(msg)
for file in record['files']:
if file['key'] == 'metadata.json':
callback()
r = requests.get(file['links']['self'])
r.raise_for_status()
try:
meta_json = r.json()
except Exception:
msg = f'Metadata for \'{record["metadata"]["title"]}\' ({record["metadata"]["doi"]}) not in JSON format'
logger.error(msg)
raise KrakenRepoException(msg)
if not meta_json:
msg = 'Mo metadata.jsn found for \'{}\' ({})'.format(record['metadata']['title'], record['metadata']['doi'])
msg = 'Mo metadata.json found for \'{}\' ({})'.format(record['metadata']['title'], record['metadata']['doi'])
logger.error(msg)
raise KrakenRepoException(msg)
# merge metadata.json into DataCite
Expand Down
19 changes: 10 additions & 9 deletions kraken/rpred.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
from collections import defaultdict
from typing import List, Tuple, Optional, Generator, Union, Dict, Sequence, TYPE_CHECKING

from kraken.containers import BaselineOCRRecord, BBoxOCRRecord, ocr_record, Segmentation
from kraken.containers import BaselineOCRRecord, BBoxOCRRecord, ocr_record
from kraken.lib.util import get_im_str, is_bitonal
from kraken.lib.segmentation import extract_polygons
from kraken.lib.exceptions import KrakenInputException
from kraken.lib.dataset import ImageInputTransforms

if TYPE_CHECKING:
from PIL import Image
from kraken.containers import Segmentation
from kraken.lib.models import TorchSeqRecognizer

__all__ = ['mm_rpred', 'rpred']
Expand All @@ -45,9 +46,9 @@ class mm_rpred(object):
Multi-model version of kraken.rpred.rpred
"""
def __init__(self,
nets: Dict[Tuple[str, str], TorchSeqRecognizer],
im: Image.Image,
bounds: Segmentation,
nets: Dict[Tuple[str, str], 'TorchSeqRecognizer'],
im: 'Image.Image',
bounds: 'Segmentation',
pad: int = 16,
bidi_reordering: Union[bool, str] = True,
tags_ignore: Optional[List[Tuple[str, str]]] = None) -> Generator[ocr_record, None, None]:
Expand Down Expand Up @@ -291,9 +292,9 @@ def _scale_val(self, val, min_val, max_val):
return int(round(min(max(((val*self.net_scale)-self.pad)*self.in_scale, min_val), max_val-1)))


def rpred(network: TorchSeqRecognizer,
im: Image.Image,
bounds: Segmentation,
def rpred(network: 'TorchSeqRecognizer',
im: 'Image.Image',
bounds: 'Segmentation',
pad: int = 16,
bidi_reordering: Union[bool, str] = True) -> Generator[ocr_record, None, None]:
"""
Expand All @@ -319,8 +320,8 @@ def rpred(network: TorchSeqRecognizer,


def _resolve_tags_to_model(tags: Sequence[Dict[str, str]],
model_map: Dict[Tuple[str, str], TorchSeqRecognizer],
default: Optional[TorchSeqRecognizer] = None) -> TorchSeqRecognizer:
model_map: Dict[Tuple[str, str], 'TorchSeqRecognizer'],
default: Optional['TorchSeqRecognizer'] = None) -> 'TorchSeqRecognizer':
"""
Resolves a sequence of tags
"""
Expand Down
4 changes: 0 additions & 4 deletions tests/test_align.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
# -*- coding: utf-8 -*-

import unittest
import json

import kraken
import dataclasses

from pathlib import Path

Expand Down
46 changes: 46 additions & 0 deletions tests/test_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
import shutil
import unittest
import tempfile

from pathlib import Path

from kraken import repo

thisfile = Path(__file__).resolve().parent
resources = thisfile / 'resources'

class TestRepo(unittest.TestCase):
"""
Testing interaction with the model repository.
"""

def setUp(self):
self.temp_model = tempfile.TemporaryDirectory()
self.temp_path = Path(self.temp_model.name)

def tearDown(self):
shutil.rmtree(self.temp_model.name)

def test_listing(self):
"""
Tests fetching the model list.
"""
records = repo.get_listing()
self.assertGreater(len(records), 15)

def test_get_description(self):
"""
Tests fetching the description of a model.
"""
record = repo.get_description('10.5281/zenodo.6657809')
self.assertEqual(record['doi'], '10.5281/zenodo.6657809')

def test_get_model(self):
"""
Tests fetching a model.
"""
id = repo.get_model('10.5281/zenodo.6657809',
path=self.temp_model.name)
self.assertEqual(id, 'HTR-United-Manu_McFrench.mlmodel')
self.assertEqual((self.temp_path / id).stat().st_size, 16176844)

0 comments on commit 5a14092

Please sign in to comment.