diff --git a/kraken/rpred.py b/kraken/rpred.py index 179d9e872..f21568b65 100644 --- a/kraken/rpred.py +++ b/kraken/rpred.py @@ -121,6 +121,13 @@ def __init__(self, valid_norm = True self.next_iter = self._recognize_box_line + if isinstance(nets, defaultdict) and nets.default_factory: + network = nets.default_factory() + batch, channels, height, width = network.nn.input + self.ts = defaultdict(lambda: ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm)) + else: + self.ts = {} + if self.have_tags: tags = set() for x in bounds.lines: @@ -148,11 +155,8 @@ def __init__(self, network = nets[tag] batch, channels, height, width = network.nn.input self.ts[tag] = ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm) - elif isinstance(nets, defaultdict) and nets.default_factory: - network = nets.default_factory() - batch, channels, height, width = network.nn.input - self.ts = {('type', 'default'): ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm)} - else: + + if not isinstance(self.ts, defaultdict) and not self.ts: raise ValueError('No tags in input data and no default model in mapping given.') self.im = im