diff --git a/kraken/contrib/new_classes.json b/kraken/contrib/new_classes.json new file mode 100644 index 000000000..dcaee30a1 --- /dev/null +++ b/kraken/contrib/new_classes.json @@ -0,0 +1 @@ +{"baselines": {"defaultLine": 2}, "regions": {"foo": 3}} diff --git a/kraken/contrib/set_seg_options.py b/kraken/contrib/set_seg_options.py index a59000689..eeea0e4d5 100755 --- a/kraken/contrib/set_seg_options.py +++ b/kraken/contrib/set_seg_options.py @@ -12,11 +12,15 @@ @click.option('--topline/--baseline', default=False, help='Sets model line type to baseline or topline') @click.option('--pad', show_default=True, type=(int, int), default=(0, 0), help='Padding (left/right, top/bottom) around the page image') +@click.option('--output-identifiers', type=click.Path(exists=True), help='Path ' + 'to a json file containing a dict updating the string identifiers ' + 'of line/region classes.') @click.argument('model', nargs=1, type=click.Path(exists=True)) -def cli(bounding_region, topline, pad, model): +def cli(bounding_region, topline, pad, output_identifiers, model): """ A script setting the metadata of segmentation models. """ + import json from kraken.lib import vgsl net = vgsl.TorchVGSLModel.load_model(model) @@ -24,29 +28,50 @@ def cli(bounding_region, topline, pad, model): print('Model is not a segmentation model.') return - print('detectable line and region types:') + print('detectable line types:') for k, v in net.user_metadata['class_mapping']['baselines'].items(): print(f' {k}\t{v}') - print('Training region types:') + print('detectable region types:') for k, v in net.user_metadata['class_mapping']['regions'].items(): print(f' {k}\t{v}') + if output_identifiers: + with open(output_identifiers, 'r') as fp: + new_cls_map = json.load(fp) + print('-> Updating class maps') + if 'baselines' in new_cls_map: + print('new baseline identifiers:') + old_cls = {v: k for k,v in net.user_metadata['class_mapping']['baselines'].items()} + new_cls = {v: k for k,v in new_cls_map['baselines'].items()} + old_cls.update(new_cls) + net.user_metadata['class_mapping']['baselines'] = {v: k for k, v in old_cls.items()} + for k, v in net.user_metadata['class_mapping']['baselines'].items(): + print(f' {k}\t{v}') + if 'regions' in new_cls_map: + print('new region identifiers:') + old_cls = {v: k for k,v in net.user_metadata['class_mapping']['regions'].items()} + new_cls = {v: k for k,v in new_cls_map['regions'].items()} + old_cls.update(new_cls) + net.user_metadata['class_mapping']['regions'] = {v: k for k, v in old_cls.items()} + for k, v in net.user_metadata['class_mapping']['regions'].items(): + print(f' {k}\t{v}') + print(f'existing bounding regions: {net.user_metadata["bounding_regions"]}') if bounding_region: br = set(net.user_metadata["bounding_regions"]) br_new = set(bounding_region) - print(f'removing: {br.difference(br_new)}') - print(f'adding: {br_new.difference(br)}') + print(f'-> removing: {br.difference(br_new)}') + print(f'-> adding: {br_new.difference(br)}') net.user_metadata["bounding_regions"] = bounding_region print(f'Model is {"topline" if "topline" in net.user_metadata and net.user_metadata["topline"] else "baseline"}') - print(f'Setting to {"topline" if topline else "baseline"}') + print(f'-> Setting to {"topline" if topline else "baseline"}') net.user_metadata['topline'] = topline print(f"Model has padding {net.user_metadata['hyper_params']['padding'] if 'padding' in net.user_metadata['hyper_params'] else (0, 0)}") - print(f'Setting to {pad}') + print(f'-> Setting to {pad}') net.user_metadata['hyper_params']['padding'] = pad shutil.copy(model, f'{model}.bak')