Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Add torch model import (#394)
Browse files Browse the repository at this point in the history
* lil refactor of import hooks
add torch model import
closes #385

* fix typo
  • Loading branch information
mike0sv authored Sep 5, 2022
1 parent 387bd4d commit 2b7f380
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 43 deletions.
39 changes: 9 additions & 30 deletions mlem/contrib/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,7 @@

from mlem.config import MlemConfigBase, project_config
from mlem.contrib.numpy import np_type_from_string, python_type_from_np_type
from mlem.core.artifacts import (
Artifact,
Artifacts,
PlaceholderArtifact,
Storage,
get_file_info,
)
from mlem.core.artifacts import Artifact, Artifacts, Storage
from mlem.core.data_type import (
DataHook,
DataReader,
Expand All @@ -50,9 +44,9 @@
SerializationError,
UnsupportedDataBatchLoadingType,
)
from mlem.core.import_objects import ExtImportHook
from mlem.core.import_objects import ExtImportHook, LoadAndAnalyzeImportHook
from mlem.core.meta_io import Location
from mlem.core.objects import MlemData, MlemObject
from mlem.core.objects import MlemData
from mlem.core.requirements import LibRequirementsMixin

_PD_EXT_TYPES = {
Expand Down Expand Up @@ -674,35 +668,20 @@ def write(
}


class PandasImport(ExtImportHook):
class PandasImport(ExtImportHook, LoadAndAnalyzeImportHook):
EXTS: ClassVar = tuple(f".{k}" for k in PANDAS_FORMATS)
type: ClassVar = "pandas"
force_type: ClassVar = MlemData

@classmethod
def is_object_valid(cls, obj: Location) -> bool:
return super().is_object_valid(obj) and obj.fs.isfile(obj.fullpath)

@classmethod
def process(
cls,
obj: Location,
copy_data: bool = True,
modifier: Optional[str] = None,
**kwargs,
) -> MlemObject:
ext = modifier or posixpath.splitext(obj.path)[1][1:]
def load_obj(cls, location: Location, modifier: Optional[str], **kwargs):
ext = modifier or posixpath.splitext(location.path)[1][1:]
fmt = PANDAS_FORMATS[ext]
read_args = fmt.read_args or {}
read_args.update(kwargs)
with obj.open("rb") as f:
data = fmt.read_func(f, **read_args)
meta = MlemData.from_data(data)
if not copy_data:
meta.artifacts = {
DataWriter.art_name: PlaceholderArtifact(
location=obj,
uri=obj.uri,
**get_file_info(obj.fullpath, obj.fs),
)
}
return meta
with location.open("rb") as f:
return fmt.read_func(f, **read_args)
25 changes: 24 additions & 1 deletion mlem/contrib/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mlem.constants import PREDICT_METHOD_NAME
from mlem.contrib.numpy import python_type_from_np_string_repr
from mlem.core.artifacts import Artifacts, Storage
from mlem.core.artifacts import Artifacts, FSSpecArtifact, Storage
from mlem.core.data_type import (
DataHook,
DataReader,
Expand All @@ -15,7 +15,10 @@
)
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.import_objects import LoadAndAnalyzeImportHook
from mlem.core.meta_io import Location
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.objects import MlemModel
from mlem.core.requirements import InstallableRequirement, Requirements


Expand Down Expand Up @@ -190,6 +193,26 @@ def get_requirements(self) -> Requirements:
)


class TorchModelImport(LoadAndAnalyzeImportHook):
type: ClassVar = "torch"
force_type: ClassVar = MlemModel

@classmethod
def is_object_valid(cls, obj: Location) -> bool:
# TODO only manual import type specification for now
return False

@classmethod
def load_obj(cls, location: Location, modifier: Optional[str], **kwargs):
return TorchModelIO().load(
{
TorchModelIO.art_name: FSSpecArtifact(
uri=location.uri, size=0, hash=""
)
}
)


# Copyright 2019 Zyfra
# Copyright 2021 Iterative
#
Expand Down
37 changes: 27 additions & 10 deletions mlem/core/import_objects.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle
from abc import ABC, abstractmethod
from typing import ClassVar, Optional, Tuple
from typing import ClassVar, Optional, Tuple, Union

from mlem.core.artifacts import PlaceholderArtifact, get_file_info
from mlem.core.base import MlemABC
Expand All @@ -9,7 +9,7 @@
from mlem.core.meta_io import Location
from mlem.core.metadata import get_object_metadata
from mlem.core.model import ModelIO
from mlem.core.objects import MlemObject
from mlem.core.objects import MlemData, MlemModel, MlemObject


class ImportHook(Hook[MlemObject], MlemABC, ABC):
Expand Down Expand Up @@ -63,11 +63,8 @@ def is_object_valid(cls, obj: Location) -> bool:
return any(obj.path.endswith(e) for e in cls.EXTS)


class PickleImportHook(ExtImportHook):
"""Import hook for pickle files"""

EXTS: ClassVar = (".pkl", ".pickle")
type: ClassVar = "pickle"
class LoadAndAnalyzeImportHook(ImportHook, ABC):
force_type: ClassVar[Union[MlemModel, MlemData, None]] = None

@classmethod
def process(
Expand All @@ -77,9 +74,13 @@ def process(
modifier: Optional[str] = None,
**kwargs,
) -> MlemObject:
with obj.open("rb") as f:
data = pickle.load(f)
meta = get_object_metadata(data, **kwargs)
data = cls.load_obj(obj, modifier, **kwargs)
if cls.force_type is None:
meta = get_object_metadata(data, **kwargs)
elif cls.force_type is MlemModel:
meta = MlemModel.from_obj(data, **kwargs)
else:
meta = MlemData.from_data(data, **kwargs)
if not copy_data:
meta.artifacts = {
ModelIO.art_name: PlaceholderArtifact(
Expand All @@ -89,3 +90,19 @@ def process(
)
}
return meta

@classmethod
def load_obj(cls, location: Location, modifier: Optional[str], **kwargs):
raise NotImplementedError


class PickleImportHook(ExtImportHook, LoadAndAnalyzeImportHook):
"""Import hook for pickle files"""

EXTS: ClassVar = (".pkl", ".pickle")
type: ClassVar = "pickle"

@classmethod
def load_obj(cls, location: Location, modifier: Optional[str], **kwargs):
with location.open("rb") as f:
return pickle.load(f)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
"model_type.onnx = mlem.contrib.onnx:ONNXModel",
"data_type.dataframe = mlem.contrib.pandas:DataFrameType",
"import.pandas = mlem.contrib.pandas:PandasImport",
"import.torch = mlem.contrib.torch:TorchModelImport",
"data_reader.pandas = mlem.contrib.pandas:PandasReader",
"data_reader.pandas_series = mlem.contrib.pandas:PandasSeriesReader",
"data_writer.pandas_series = mlem.contrib.pandas:PandasSeriesWriter",
Expand Down
25 changes: 23 additions & 2 deletions tests/contrib/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import pytest
import torch

from mlem.api import save
from mlem.api import import_object, save
from mlem.constants import PREDICT_METHOD_NAME
from mlem.contrib.torch import TorchModelIO, TorchTensorReader
from mlem.contrib.torch import (
TorchModel,
TorchModelImport,
TorchModelIO,
TorchTensorReader,
)
from mlem.core.artifacts import LOCAL_STORAGE
from mlem.core.data_type import DataAnalyzer, DataType
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.model import ModelAnalyzer
from mlem.core.objects import MlemModel
from tests.conftest import data_write_read_check


Expand Down Expand Up @@ -153,6 +159,21 @@ def check_model(net, input_data, tmpdir):
save(net, str(tmpdir / "torch-net"), sample_data=input_data)


@pytest.mark.parametrize(
"net,torchsave",
[
(torch.nn.Linear(5, 1), torch.save),
(torch.jit.script(torch.nn.Linear(5, 1)), torch.jit.save),
],
)
def test_torch_import(tmp_path, net, torchsave):
path = tmp_path / "model"
torchsave(net, path)
meta = import_object(str(path), type_=TorchModelImport.type)
assert isinstance(meta, MlemModel)
assert isinstance(meta.model_type, TorchModel)


# Copyright 2019 Zyfra
# Copyright 2021 Iterative
#
Expand Down

0 comments on commit 2b7f380

Please sign in to comment.