Skip to content

Commit

Permalink
add sentencepiececodec serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Apr 7, 2024
1 parent 17b0e48 commit 158e84e
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions kraken/lib/vgsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch import nn

from kraken.lib import layers
from kraken.lib.codec import PytorchCodec
from kraken.lib.codec import PytorchCodec, SentencePieceCodec
from kraken.lib.exceptions import KrakenInvalidModelException

# all tensors are ordered NCHW, the "feature" dimension is C, so the output of
Expand Down Expand Up @@ -305,7 +305,9 @@ def _deserialize_layers(name, layer):
nn.aux_layers = {k: cls(v).nn.get_submodule(k) for k, v in json.loads(mlmodel.user_defined_metadata['aux_layers']).items()}

if 'codec' in mlmodel.user_defined_metadata:
nn.add_codec(PytorchCodec(json.loads(mlmodel.user_defined_metadata['codec'])))
codec = json.loads(mlmodel.user_defined_metadata['codec'])
if codec['type'] == 'SentencePiece':
nn.add_codec(SentencePieceCodec(codec['spp']))

nn.user_metadata: Dict[str, Any] = {'accuracy': [],
'metrics': [],
Expand Down Expand Up @@ -409,7 +411,7 @@ def _serialize_layer(net, input, net_builder):
mlmodel.short_description = 'kraken model'
mlmodel.user_defined_metadata['vgsl'] = '[' + ' '.join(self.named_spec) + ']'
if self.codec:
mlmodel.user_defined_metadata['codec'] = json.dumps(self.codec.c2l)
mlmodel.user_defined_metadata['codec'] = json.dumps({'type': 'SentencePiece', 'spp': self.spp.serialized_model_proto()})
if self.user_metadata:
mlmodel.user_defined_metadata['kraken_meta'] = json.dumps(self.user_metadata)
if self.aux_layers:
Expand Down

0 comments on commit 158e84e

Please sign in to comment.