This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Add huggingface transformers support #553

Open · wants to merge 4 commits into base: main
Changes from 1 commit
Add tests for batchencoding type
mike0sv committed Dec 30, 2022
commit ef3c5d778bfce47fc53de61b3154b288d61ba74f
8 changes: 8 additions & 0 deletions mlem/contrib/numpy.py
@@ -261,3 +261,11 @@ def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DataType]:
raise NotImplementedError


def apply_shape_pattern(
abs_shape: Tuple[Optional[int], ...], shape: Tuple[int, ...]
):
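    """Fill each None dimension in abs_shape with the corresponding value from shape."""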
return tuple(
s if s is not None else shape[i] for i, s in enumerate(abs_shape)
)
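For intuition, a quick sketch of how the helper behaves (the asserts below are illustrative, not part of this PR): None entries act as wildcards filled from the concrete shape, while fixed entries are kept as declared, so a mismatch survives into the comparison done by check_shape.

    from mlem.contrib.numpy import apply_shape_pattern

    # None is a wildcard: the concrete dimension is taken from the tensor shape.
    assert apply_shape_pattern((None, 3), (8, 3)) == (8, 3)
    # Fixed dimensions are kept as declared, so a tensor of shape (8, 4)
    # fails the caller's (8, 4) != (8, 3) comparison and raises there.
    assert apply_shape_pattern((None, 3), (8, 4)) == (8, 3)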
9 changes: 7 additions & 2 deletions mlem/contrib/tensorflow.py
@@ -14,7 +14,10 @@
from pydantic import conlist, create_model
from tensorflow.python.keras.saving.saved_model_experimental import sequential

-from mlem.contrib.numpy import python_type_from_np_string_repr
+from mlem.contrib.numpy import (
+    apply_shape_pattern,
+    python_type_from_np_string_repr,
+)
from mlem.core.artifacts import Artifacts, Storage
from mlem.core.data_type import (
DataHook,
@@ -60,7 +63,9 @@ def tf_type(self):
return getattr(tf, self.dtype)

def check_shape(self, tensor, exc_type):
-        if tuple(tensor.shape)[1:] != self.shape[1:]:
+        if tuple(tensor.shape) != apply_shape_pattern(
+            self.shape, tensor.shape
+        ):
raise exc_type(
f"given tensor is of shape: {(None,) + tuple(tensor.shape)[1:]}, expected: {self.shape}"
)
13 changes: 7 additions & 6 deletions mlem/contrib/torch.py
@@ -13,7 +13,10 @@
from pydantic import conlist, create_model

from mlem.config import MlemConfigBase
-from mlem.contrib.numpy import python_type_from_np_string_repr
+from mlem.contrib.numpy import (
+    apply_shape_pattern,
+    python_type_from_np_string_repr,
+)
from mlem.core.artifacts import Artifacts, FSSpecArtifact, Storage
from mlem.core.data_type import (
DataHook,
@@ -66,11 +69,9 @@ class TorchTensorDataType(
"""Type name of `torch.Tensor` elements"""

def check_shape(self, tensor, exc_type):
-        shape = tuple(
-            s if s is not None else tensor.shape[i]
-            for i, s in enumerate(self.shape)
-        )
-        if tuple(tensor.shape) != shape:
+        if tuple(tensor.shape) != apply_shape_pattern(
+            self.shape, tensor.shape
+        ):
raise exc_type(
f"given tensor is of shape: {(None,) + tuple(tensor.shape)[1:]}, expected: {self.shape}"
)
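For TensorFlow this is a behavior change, not just a refactor: previously only the leading (batch) dimension was unconstrained, whereas now every dimension declared as None is dynamic. For Torch the same fill-in logic already existed inline and is merely extracted. A small sketch of the difference (names local to this sketch):

    from mlem.contrib.numpy import apply_shape_pattern

    declared = (None, 10, None)  # dynamic batch and trailing dimension

    def old_tf_ok(shape):
        # Old TensorFlow rule: everything after the batch dim compared literally.
        return tuple(shape)[1:] == declared[1:]

    def new_ok(shape):
        # New shared rule: every None in the declared shape is a wildcard.
        return tuple(shape) == apply_shape_pattern(declared, tuple(shape))

    print(old_tf_ok((4, 10, 7)))  # False: 7 != None under literal comparison
    print(new_ok((4, 10, 7)))     # True: both None dims filled from the tensor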
64 changes: 58 additions & 6 deletions mlem/contrib/transformers.py
@@ -2,7 +2,7 @@
import tempfile
from enum import Enum
from importlib import import_module
-from typing import Any, ClassVar, Dict, Optional
+from typing import Any, ClassVar, Dict, Iterator, Optional, Tuple

from transformers import (
AutoModel,
@@ -13,13 +13,16 @@
)
from transformers.modeling_utils import PreTrainedModel

-from mlem.core.artifacts import Artifacts
+from mlem.core.artifacts import Artifacts, Storage
from mlem.core.data_type import (
DataAnalyzer,
DataHook,
DataType,
DataWriter,
DictReader,
DictSerializer,
DictType,
DictWriter,
)
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import BufferModelIO, ModelHook, ModelType, Signature
@@ -120,7 +123,7 @@ def get_requirements(self) -> Requirements:
return reqs


-_ADDITIONAL_DEPS = {
+ADDITIONAL_DEPS = {
TensorType.NUMPY: "numpy",
TensorType.PYTORCH: "torch",
TensorType.TENSORFLOW: "tensorflow",
@@ -131,6 +134,7 @@ class BatchEncodingType(DictType, DataHook, IsInstanceHookMixin):
class Config:
use_enum_values = True

type: ClassVar = "batch_encoding"
valid_types: ClassVar = BatchEncoding
return_tensors: Optional[TensorType] = None

@@ -150,6 +154,14 @@ def get_tensors_type(obj: BatchEncoding) -> Optional[TensorType]:
return None
raise ValueError(f"Unknown tensor type {type_}")

@property
def return_tensors_enum(self) -> Optional[TensorType]:
if self.return_tensors is not None and not isinstance(
self.return_tensors, TensorType
):
return TensorType(self.return_tensors)
return self.return_tensors

@classmethod
def process(cls, obj: BatchEncoding, **kwargs) -> DataType:
return BatchEncodingType(
@@ -162,19 +174,59 @@ def process(cls, obj: BatchEncoding, **kwargs) -> DataType:

def get_requirements(self) -> Requirements:
new = Requirements.new("transformers")
-        if self.return_tensors in _ADDITIONAL_DEPS:
-            new += Requirements.new(_ADDITIONAL_DEPS[self.return_tensors])
+        if self.return_tensors_enum in ADDITIONAL_DEPS:
+            new += Requirements.new(ADDITIONAL_DEPS[self.return_tensors_enum])
return new

def get_writer(
self, project: str = None, filename: str = None, **kwargs
) -> DataWriter:
return BatchEncodingWriter(**kwargs)


class BatchEncodingSerializer(DictSerializer):
data_class: ClassVar = BatchEncodingType
is_default: ClassVar = True

@staticmethod
def _check_type_and_keys(data_type, obj, exc_type):
-        data_type.check_type(obj, BatchEncoding, exc_type)
+        data_type.check_type(obj, (dict, BatchEncoding), exc_type)
if set(obj.keys()) != set(data_type.item_types.keys()):
raise exc_type(
f"given dict has keys: {set(obj.keys())}, expected: {set(data_type.item_types.keys())}"
)

def deserialize(self, data_type: DictType, obj):
assert isinstance(data_type, BatchEncodingType)
return BatchEncoding(
super().deserialize(data_type, obj),
tensor_type=data_type.return_tensors_enum,
)


class BatchEncodingReader(DictReader):
type: ClassVar = "batch_encoding"

def read(self, artifacts: Artifacts) -> DictType:
res = super().read(artifacts)
return res.bind(BatchEncoding(res.data))

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DictType]:
raise NotImplementedError


class BatchEncodingWriter(DictWriter):
type: ClassVar = "batch_encoding"

def write(
self, data: DataType, storage: Storage, path: str
) -> Tuple[DictReader, Artifacts]:
res, art = super().write(data, storage, path)
return (
BatchEncodingReader(
data_type=res.data_type, item_readers=res.item_readers
),
art,
)
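Taken together, these classes give BatchEncoding a complete serialize/deserialize and write/read cycle. A minimal round-trip sketch mirroring the tests below (the tokenizer name is illustrative):

    from transformers import AlbertTokenizer

    from mlem.core.data_type import DataAnalyzer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    data = tokenizer("aaa bbb", return_tensors="np")

    data_type = DataAnalyzer.analyze(data)  # -> BatchEncodingType
    payload = data_type.get_serializer().serialize(data)
    restored = data_type.get_serializer().deserialize(payload)
    # deserialize() rebuilds a BatchEncoding and re-applies return_tensors,
    # so restored["input_ids"] is a numpy array again.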
2 changes: 1 addition & 1 deletion mlem/core/data_type.py
@@ -816,7 +816,7 @@ class DictWriter(DataWriter):

def write(
self, data: DataType, storage: Storage, path: str
-    ) -> Tuple[DataReader, Artifacts]:
+    ) -> Tuple["DictReader", Artifacts]:
if not isinstance(data, DictType):
raise ValueError(
f"expected data to be of DictType, got {type(data)} instead"
175 changes: 175 additions & 0 deletions tests/contrib/test_transformers.py
@@ -0,0 +1,175 @@
from functools import partial

import numpy as np
import pytest
import tensorflow as tf
import torch
from pydantic import parse_obj_as
from transformers import (
AlbertModel,
AlbertTokenizer,
BatchEncoding,
DistilBertModel,
DistilBertTokenizer,
TensorType,
)

from mlem.contrib.transformers import ADDITIONAL_DEPS, BatchEncodingType
from mlem.core.data_type import DataAnalyzer, DataType
from tests.conftest import data_write_read_check

FULL_TESTS = True

TOKENIZERS = {
AlbertTokenizer: "albert-base-v2",
DistilBertTokenizer: "distilbert-base-uncased",
}

MODELS = {
AlbertModel: "albert-base-v2",
DistilBertModel: "distilbert-base-uncased",
}

ONE_MODEL = AlbertModel
ONE_TOKENIZER = AlbertTokenizer

for_model = pytest.mark.parametrize(
"model",
[ONE_MODEL.from_pretrained(MODELS[ONE_MODEL])]
if not FULL_TESTS
else [m.from_pretrained(v) for m, v in MODELS.items()],
)

for_tokenizer = pytest.mark.parametrize(
"tokenizer",
[ONE_TOKENIZER.from_pretrained(TOKENIZERS[ONE_TOKENIZER])]
if not FULL_TESTS
else [m.from_pretrained(v) for m, v in TOKENIZERS.items()],
)


def test_analyzing_model():
pass


def test_analyzing_tokenizer():
pass


def test_serving_model():
pass


def test_serving_tokenizer():
pass


def test_model_reqs():
pass


def test_tokenizer_reqs():
pass


# pylint: disable=protected-access
@for_tokenizer
@pytest.mark.parametrize(
"return_tensors,typename,eq",
[
("pt", "TorchTensor", lambda a, b: torch.all(a.eq(b))),
("tf", "TFTensor", lambda a, b: tf.equal(a, b)._numpy().all()),
("np", "NumpyNdarray", lambda a, b: np.equal(a, b).all()),
(None, "Array", None),
],
)
def test_batch_encoding(tokenizer, return_tensors, typename, eq):
data = tokenizer("aaa bbb", return_tensors=return_tensors)

data_type = DataAnalyzer.analyze(data)
assert isinstance(data_type, BatchEncodingType)
expected_reqs = ["transformers"]
if return_tensors is not None:
expected_reqs += [ADDITIONAL_DEPS[TensorType(return_tensors)]]
assert data_type.get_requirements().modules == expected_reqs

item_type = DataAnalyzer.analyze(data["input_ids"], is_dynamic=True).dict()
expected_payload = {
"item_types": {
"attention_mask": item_type,
"input_ids": item_type,
"token_type_ids": item_type,
},
"type": "batch_encoding",
}
if return_tensors is not None:
expected_payload["return_tensors"] = return_tensors
if "token_type_ids" not in data:
del expected_payload["item_types"]["token_type_ids"]
assert data_type.dict() == expected_payload
data_type2 = parse_obj_as(DataType, data_type.dict())
assert data_type2 == data_type

assert data_type.get_model().__name__ == data_type2.get_model().__name__
schema_item_type = {"items": {"type": "integer"}, "type": "array"}
if return_tensors is None:
schema_item_type = {"type": "integer"}
expected_schema = {
"definitions": {
f"attention_mask_{typename}": {
"items": schema_item_type,
"title": f"attention_mask_{typename}",
"type": "array",
},
f"input_ids_{typename}": {
"items": schema_item_type,
"title": f"input_ids_{typename}",
"type": "array",
},
f"token_type_ids_{typename}": {
"items": schema_item_type,
"title": f"token_type_ids_{typename}",
"type": "array",
},
},
"properties": {
"attention_mask": {
"$ref": f"#/definitions/attention_mask_{typename}"
},
"input_ids": {"$ref": f"#/definitions/input_ids_{typename}"},
"token_type_ids": {
"$ref": f"#/definitions/token_type_ids_{typename}"
},
},
"required": ["input_ids", "token_type_ids", "attention_mask"],
"title": "DictType",
"type": "object",
}
if "token_type_ids" not in data:
del expected_schema["definitions"][f"token_type_ids_{typename}"]
del expected_schema["properties"]["token_type_ids"]
expected_schema["required"].remove("token_type_ids")
assert data_type.get_model().schema() == expected_schema
n_payload = data_type.get_serializer().serialize(data)
deser = data_type.get_serializer().deserialize(n_payload)
assert _batch_encoding_equals(data, deser, eq)
parse_obj_as(data_type.get_model(), n_payload)

data_type = data_type.bind(data)
data_write_read_check(
data_type, custom_eq=partial(_batch_encoding_equals, equals=eq)
)


def _batch_encoding_equals(first, second, equals):
assert isinstance(first, BatchEncoding)
assert isinstance(second, BatchEncoding)

assert first.keys() == second.keys()

for key in first:
if equals is not None:
assert equals(first[key], second[key])
else:
assert first[key] == second[key]
return True