diff --git a/kraken/lib/vgsl.py b/kraken/lib/vgsl.py index 9a4f0759b..745dda816 100644 --- a/kraken/lib/vgsl.py +++ b/kraken/lib/vgsl.py @@ -15,7 +15,7 @@ from torch import nn from kraken.lib import layers -from kraken.lib.codec import PytorchCodec +from kraken.lib.codec import PytorchCodec, SentencePieceCodec from kraken.lib.exceptions import KrakenInvalidModelException # all tensors are ordered NCHW, the "feature" dimension is C, so the output of @@ -305,7 +305,9 @@ def _deserialize_layers(name, layer): nn.aux_layers = {k: cls(v).nn.get_submodule(k) for k, v in json.loads(mlmodel.user_defined_metadata['aux_layers']).items()} if 'codec' in mlmodel.user_defined_metadata: - nn.add_codec(PytorchCodec(json.loads(mlmodel.user_defined_metadata['codec']))) + codec = json.loads(mlmodel.user_defined_metadata['codec']) + if codec['type'] == 'SentencePiece': + nn.add_codec(SentencePieceCodec(codec['spp'])) nn.user_metadata: Dict[str, Any] = {'accuracy': [], 'metrics': [], @@ -409,7 +411,7 @@ def _serialize_layer(net, input, net_builder): mlmodel.short_description = 'kraken model' mlmodel.user_defined_metadata['vgsl'] = '[' + ' '.join(self.named_spec) + ']' if self.codec: - mlmodel.user_defined_metadata['codec'] = json.dumps(self.codec.c2l) + mlmodel.user_defined_metadata['codec'] = json.dumps({'type': 'SentencePiece', 'spp': self.spp.serialized_model_proto()}) if self.user_metadata: mlmodel.user_defined_metadata['kraken_meta'] = json.dumps(self.user_metadata) if self.aux_layers: