Skip to content

Commit

Permalink
renaming class identifiers functionality in set_seg_options
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Oct 7, 2024
1 parent 25f2bc2 commit 30716fe
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
1 change: 1 addition & 0 deletions kraken/contrib/new_classes.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"baselines": {"defaultLine": 2}, "regions": {"foo": 3}}
39 changes: 32 additions & 7 deletions kraken/contrib/set_seg_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,41 +12,66 @@
@click.option('--topline/--baseline', default=False, help='Sets model line type to baseline or topline')
@click.option('--pad', show_default=True, type=(int, int), default=(0, 0),
help='Padding (left/right, top/bottom) around the page image')
@click.option('--output-identifiers', type=click.Path(exists=True), help='Path '
'to a json file containing a dict updating the string identifiers '
'of line/region classes.')
@click.argument('model', nargs=1, type=click.Path(exists=True))
def cli(bounding_region, topline, pad, model):
def cli(bounding_region, topline, pad, output_identifiers, model):
"""
A script setting the metadata of segmentation models.
"""
import json
from kraken.lib import vgsl

net = vgsl.TorchVGSLModel.load_model(model)
if net.model_type != 'segmentation':
print('Model is not a segmentation model.')
return

print('detectable line and region types:')
print('detectable line types:')
for k, v in net.user_metadata['class_mapping']['baselines'].items():
print(f' {k}\t{v}')
print('Training region types:')
print('detectable region types:')
for k, v in net.user_metadata['class_mapping']['regions'].items():
print(f' {k}\t{v}')

if output_identifiers:
with open(output_identifiers, 'r') as fp:
new_cls_map = json.load(fp)
print('-> Updating class maps')
if 'baselines' in new_cls_map:
print('new baseline identifiers:')
old_cls = {v: k for k,v in net.user_metadata['class_mapping']['baselines'].items()}
new_cls = {v: k for k,v in new_cls_map['baselines'].items()}
old_cls.update(new_cls)
net.user_metadata['class_mapping']['baselines'] = {v: k for k, v in old_cls.items()}
for k, v in net.user_metadata['class_mapping']['baselines'].items():
print(f' {k}\t{v}')
if 'regions' in new_cls_map:
print('new region identifiers:')
old_cls = {v: k for k,v in net.user_metadata['class_mapping']['regions'].items()}
new_cls = {v: k for k,v in new_cls_map['regions'].items()}
old_cls.update(new_cls)
net.user_metadata['class_mapping']['regions'] = {v: k for k, v in old_cls.items()}
for k, v in net.user_metadata['class_mapping']['regions'].items():
print(f' {k}\t{v}')

print(f'existing bounding regions: {net.user_metadata["bounding_regions"]}')

if bounding_region:
br = set(net.user_metadata["bounding_regions"])
br_new = set(bounding_region)

print(f'removing: {br.difference(br_new)}')
print(f'adding: {br_new.difference(br)}')
print(f'-> removing: {br.difference(br_new)}')
print(f'-> adding: {br_new.difference(br)}')
net.user_metadata["bounding_regions"] = bounding_region

print(f'Model is {"topline" if "topline" in net.user_metadata and net.user_metadata["topline"] else "baseline"}')
print(f'Setting to {"topline" if topline else "baseline"}')
print(f'-> Setting to {"topline" if topline else "baseline"}')
net.user_metadata['topline'] = topline

print(f"Model has padding {net.user_metadata['hyper_params']['padding'] if 'padding' in net.user_metadata['hyper_params'] else (0, 0)}")
print(f'Setting to {pad}')
print(f'-> Setting to {pad}')
net.user_metadata['hyper_params']['padding'] = pad

shutil.copy(model, f'{model}.bak')
Expand Down

0 comments on commit 30716fe

Please sign in to comment.