Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Add torch model import #394

Merged
merged 2 commits into from
Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 9 additions & 30 deletions mlem/contrib/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,7 @@

from mlem.config import MlemConfigBase, project_config
from mlem.contrib.numpy import np_type_from_string, python_type_from_np_type
from mlem.core.artifacts import (
Artifact,
Artifacts,
PlaceholderArtifact,
Storage,
get_file_info,
)
from mlem.core.artifacts import Artifact, Artifacts, Storage
from mlem.core.data_type import (
DataHook,
DataReader,
Expand All @@ -50,9 +44,9 @@
SerializationError,
UnsupportedDataBatchLoadingType,
)
from mlem.core.import_objects import ExtImportHook
from mlem.core.import_objects import ExtImportHook, LoadAndAnalyzeImportHook
from mlem.core.meta_io import Location
from mlem.core.objects import MlemData, MlemObject
from mlem.core.objects import MlemData
from mlem.core.requirements import LibRequirementsMixin

_PD_EXT_TYPES = {
Expand Down Expand Up @@ -674,35 +668,20 @@ def write(
}


class PandasImport(ExtImportHook):
class PandasImport(ExtImportHook, LoadAndAnalyzeImportHook):
EXTS: ClassVar = tuple(f".{k}" for k in PANDAS_FORMATS)
type: ClassVar = "pandas"
force_type: ClassVar = MlemData

@classmethod
def is_object_valid(cls, obj: Location) -> bool:
return super().is_object_valid(obj) and obj.fs.isfile(obj.fullpath)

@classmethod
def process(
cls,
obj: Location,
copy_data: bool = True,
modifier: Optional[str] = None,
**kwargs,
) -> MlemObject:
ext = modifier or posixpath.splitext(obj.path)[1][1:]
def load_obj(cls, location: Location, modifier: Optional[str], **kwargs):
ext = modifier or posixpath.splitext(location.path)[1][1:]
fmt = PANDAS_FORMATS[ext]
read_args = fmt.read_args or {}
read_args.update(kwargs)
with obj.open("rb") as f:
data = fmt.read_func(f, **read_args)
meta = MlemData.from_data(data)
if not copy_data:
meta.artifacts = {
DataWriter.art_name: PlaceholderArtifact(
location=obj,
uri=obj.uri,
**get_file_info(obj.fullpath, obj.fs),
)
}
return meta
with location.open("rb") as f:
return fmt.read_func(f, **read_args)
25 changes: 24 additions & 1 deletion mlem/contrib/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mlem.constants import PREDICT_METHOD_NAME
from mlem.contrib.numpy import python_type_from_np_string_repr
from mlem.core.artifacts import Artifacts, Storage
from mlem.core.artifacts import Artifacts, FSSpecArtifact, Storage
from mlem.core.data_type import (
DataHook,
DataReader,
Expand All @@ -15,7 +15,10 @@
)
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.import_objects import LoadAndAnalyzeImportHook
from mlem.core.meta_io import Location
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.objects import MlemModel
from mlem.core.requirements import InstallableRequirement, Requirements


Expand Down Expand Up @@ -190,6 +193,26 @@ def get_requirements(self) -> Requirements:
)


class TorchModelImport(LoadAndAnalyzeImportHook):
type: ClassVar = "torch"
force_type: ClassVar = MlemModel

@classmethod
def is_object_valid(cls, obj: Location) -> bool:
# TODO only manual import type specification for now
return False

@classmethod
def load_obj(cls, location: Location, modifier: Optional[str], **kwargs):
return TorchModelIO().load(
{
TorchModelIO.art_name: FSSpecArtifact(
uri=location.uri, size=0, hash=""
)
}
)


# Copyright 2019 Zyfra
# Copyright 2021 Iterative
#
Expand Down
37 changes: 27 additions & 10 deletions mlem/core/import_objects.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle
from abc import ABC, abstractmethod
from typing import ClassVar, Optional, Tuple
from typing import ClassVar, Optional, Tuple, Union

from mlem.core.artifacts import PlaceholderArtifact, get_file_info
from mlem.core.base import MlemABC
Expand All @@ -9,7 +9,7 @@
from mlem.core.meta_io import Location
from mlem.core.metadata import get_object_metadata
from mlem.core.model import ModelIO
from mlem.core.objects import MlemObject
from mlem.core.objects import MlemData, MlemModel, MlemObject


class ImportHook(Hook[MlemObject], MlemABC, ABC):
Expand Down Expand Up @@ -63,11 +63,8 @@ def is_object_valid(cls, obj: Location) -> bool:
return any(obj.path.endswith(e) for e in cls.EXTS)


class PickleImportHook(ExtImportHook):
"""Import hook for pickle files"""

EXTS: ClassVar = (".pkl", ".pickle")
type: ClassVar = "pickle"
class LoadAndAnalyzeImportHook(ImportHook, ABC):
force_type: ClassVar[Union[MlemModel, MlemData, None]] = None

@classmethod
def process(
Expand All @@ -77,9 +74,13 @@ def process(
modifier: Optional[str] = None,
**kwargs,
) -> MlemObject:
with obj.open("rb") as f:
data = pickle.load(f)
meta = get_object_metadata(data, **kwargs)
data = cls.load_obj(obj, modifier, **kwargs)
if cls.force_type is None:
meta = get_object_metadata(data, **kwargs)
elif cls.force_type is MlemModel:
meta = MlemModel.from_obj(data, **kwargs)
else:
meta = MlemData.from_data(data, **kwargs)
if not copy_data:
meta.artifacts = {
ModelIO.art_name: PlaceholderArtifact(
Expand All @@ -89,3 +90,19 @@ def process(
)
}
return meta

@classmethod
def load_obj(cls, location: Location, modifier: Optional[str], **kwargs):
raise NotImplementedError


class PickleImportHook(ExtImportHook, LoadAndAnalyzeImportHook):
"""Import hook for pickle files"""

EXTS: ClassVar = (".pkl", ".pickle")
type: ClassVar = "pickle"

@classmethod
def load_obj(cls, location: Location, modifier: Optional[str], **kwargs):
with location.open("rb") as f:
return pickle.load(f)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
"model_type.onnx = mlem.contrib.onnx:ONNXModel",
"data_type.dataframe = mlem.contrib.pandas:DataFrameType",
"import.pandas = mlem.contrib.pandas:PandasImport",
"import.torch = mlem.contrib.torch:TorchModelImport",
"data_reader.pandas = mlem.contrib.pandas:PandasReader",
"data_reader.pandas_series = mlem.contrib.pandas:PandasSeriesReader",
"data_writer.pandas_series = mlem.contrib.pandas:PandasSeriesWriter",
Expand Down
25 changes: 23 additions & 2 deletions tests/contrib/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import pytest
import torch

from mlem.api import save
from mlem.api import import_object, save
from mlem.constants import PREDICT_METHOD_NAME
from mlem.contrib.torch import TorchModelIO, TorchTensorReader
from mlem.contrib.torch import (
TorchModel,
TorchModelImport,
TorchModelIO,
TorchTensorReader,
)
from mlem.core.artifacts import LOCAL_STORAGE
from mlem.core.data_type import DataAnalyzer, DataType
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.model import ModelAnalyzer
from mlem.core.objects import MlemModel
from tests.conftest import data_write_read_check


Expand Down Expand Up @@ -153,6 +159,21 @@ def check_model(net, input_data, tmpdir):
save(net, str(tmpdir / "torch-net"), sample_data=input_data)


@pytest.mark.parametrize(
"net,torchsave",
[
(torch.nn.Linear(5, 1), torch.save),
(torch.jit.script(torch.nn.Linear(5, 1)), torch.jit.save),
],
)
def test_torch_import(tmp_path, net, torchsave):
path = tmp_path / "model"
torchsave(net, path)
meta = import_object(str(path), type_=TorchModelImport.type)
assert isinstance(meta, MlemModel)
assert isinstance(meta.model_type, TorchModel)


# Copyright 2019 Zyfra
# Copyright 2021 Iterative
#
Expand Down