diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py index 3a62a4a8..601c9c32 100644 --- a/mlem/contrib/pandas.py +++ b/mlem/contrib/pandas.py @@ -31,13 +31,7 @@ from mlem.config import MlemConfigBase, project_config from mlem.contrib.numpy import np_type_from_string, python_type_from_np_type -from mlem.core.artifacts import ( - Artifact, - Artifacts, - PlaceholderArtifact, - Storage, - get_file_info, -) +from mlem.core.artifacts import Artifact, Artifacts, Storage from mlem.core.data_type import ( DataHook, DataReader, @@ -50,9 +44,9 @@ SerializationError, UnsupportedDataBatchLoadingType, ) -from mlem.core.import_objects import ExtImportHook +from mlem.core.import_objects import ExtImportHook, LoadAndAnalyzeImportHook from mlem.core.meta_io import Location -from mlem.core.objects import MlemData, MlemObject +from mlem.core.objects import MlemData from mlem.core.requirements import LibRequirementsMixin _PD_EXT_TYPES = { @@ -674,35 +668,20 @@ def write( } -class PandasImport(ExtImportHook): +class PandasImport(ExtImportHook, LoadAndAnalyzeImportHook): EXTS: ClassVar = tuple(f".{k}" for k in PANDAS_FORMATS) type: ClassVar = "pandas" + force_type: ClassVar = MlemData @classmethod def is_object_valid(cls, obj: Location) -> bool: return super().is_object_valid(obj) and obj.fs.isfile(obj.fullpath) @classmethod - def process( - cls, - obj: Location, - copy_data: bool = True, - modifier: Optional[str] = None, - **kwargs, - ) -> MlemObject: - ext = modifier or posixpath.splitext(obj.path)[1][1:] + def load_obj(cls, location: Location, modifier: Optional[str], **kwargs): + ext = modifier or posixpath.splitext(location.path)[1][1:] fmt = PANDAS_FORMATS[ext] read_args = fmt.read_args or {} read_args.update(kwargs) - with obj.open("rb") as f: - data = fmt.read_func(f, **read_args) - meta = MlemData.from_data(data) - if not copy_data: - meta.artifacts = { - DataWriter.art_name: PlaceholderArtifact( - location=obj, - uri=obj.uri, - **get_file_info(obj.fullpath, obj.fs), - ) - } - return meta + with location.open("rb") as f: + return fmt.read_func(f, **read_args) diff --git a/mlem/contrib/torch.py b/mlem/contrib/torch.py index c2b373b1..bc9e489f 100644 --- a/mlem/contrib/torch.py +++ b/mlem/contrib/torch.py @@ -5,7 +5,7 @@ from mlem.constants import PREDICT_METHOD_NAME from mlem.contrib.numpy import python_type_from_np_string_repr -from mlem.core.artifacts import Artifacts, Storage +from mlem.core.artifacts import Artifacts, FSSpecArtifact, Storage from mlem.core.data_type import ( DataHook, DataReader, @@ -15,7 +15,10 @@ ) from mlem.core.errors import DeserializationError, SerializationError from mlem.core.hooks import IsInstanceHookMixin +from mlem.core.import_objects import LoadAndAnalyzeImportHook +from mlem.core.meta_io import Location from mlem.core.model import ModelHook, ModelIO, ModelType, Signature +from mlem.core.objects import MlemModel from mlem.core.requirements import InstallableRequirement, Requirements @@ -190,6 +193,26 @@ def get_requirements(self) -> Requirements: ) +class TorchModelImport(LoadAndAnalyzeImportHook): + type: ClassVar = "torch" + force_type: ClassVar = MlemModel + + @classmethod + def is_object_valid(cls, obj: Location) -> bool: + # TODO only manual import type specification for now + return False + + @classmethod + def load_obj(cls, location: Location, modifier: Optional[str], **kwargs): + return TorchModelIO().load( + { + TorchModelIO.art_name: FSSpecArtifact( + uri=location.uri, size=0, hash="" + ) + } + ) + + # Copyright 2019 Zyfra # Copyright 2021 Iterative # diff --git a/mlem/core/import_objects.py b/mlem/core/import_objects.py index b7f2e318..43afc0ca 100644 --- a/mlem/core/import_objects.py +++ b/mlem/core/import_objects.py @@ -1,6 +1,6 @@ import pickle from abc import ABC, abstractmethod -from typing import ClassVar, Optional, Tuple +from typing import ClassVar, Optional, Tuple, Union from mlem.core.artifacts import PlaceholderArtifact, get_file_info from mlem.core.base import MlemABC @@ -9,7 +9,7 @@ from mlem.core.meta_io import Location from mlem.core.metadata import get_object_metadata from mlem.core.model import ModelIO -from mlem.core.objects import MlemObject +from mlem.core.objects import MlemData, MlemModel, MlemObject class ImportHook(Hook[MlemObject], MlemABC, ABC): @@ -63,11 +63,8 @@ def is_object_valid(cls, obj: Location) -> bool: return any(obj.path.endswith(e) for e in cls.EXTS) -class PickleImportHook(ExtImportHook): - """Import hook for pickle files""" - - EXTS: ClassVar = (".pkl", ".pickle") - type: ClassVar = "pickle" +class LoadAndAnalyzeImportHook(ImportHook, ABC): + force_type: ClassVar[Union[MlemModel, MlemData, None]] = None @classmethod def process( @@ -77,9 +74,13 @@ def process( modifier: Optional[str] = None, **kwargs, ) -> MlemObject: - with obj.open("rb") as f: - data = pickle.load(f) - meta = get_object_metadata(data, **kwargs) + data = cls.load_obj(obj, modifier, **kwargs) + if cls.force_type is None: + meta = get_object_metadata(data, **kwargs) + elif cls.force_type is MlemModel: + meta = MlemModel.from_obj(data, **kwargs) + else: + meta = MlemData.from_data(data, **kwargs) if not copy_data: meta.artifacts = { ModelIO.art_name: PlaceholderArtifact( @@ -89,3 +90,19 @@ def process( ) } return meta + + @classmethod + def load_obj(cls, location: Location, modifier: Optional[str], **kwargs): + raise NotImplementedError + + +class PickleImportHook(ExtImportHook, LoadAndAnalyzeImportHook): + """Import hook for pickle files""" + + EXTS: ClassVar = (".pkl", ".pickle") + type: ClassVar = "pickle" + + @classmethod + def load_obj(cls, location: Location, modifier: Optional[str], **kwargs): + with location.open("rb") as f: + return pickle.load(f) diff --git a/setup.py b/setup.py index afbc75de..1ee5dbd3 100644 --- a/setup.py +++ b/setup.py @@ -178,6 +178,7 @@ "model_type.onnx = mlem.contrib.onnx:ONNXModel", "data_type.dataframe = mlem.contrib.pandas:DataFrameType", "import.pandas = mlem.contrib.pandas:PandasImport", + "import.torch = mlem.contrib.torch:TorchModelImport", "data_reader.pandas = mlem.contrib.pandas:PandasReader", "data_reader.pandas_series = mlem.contrib.pandas:PandasSeriesReader", "data_writer.pandas_series = mlem.contrib.pandas:PandasSeriesWriter", diff --git a/tests/contrib/test_torch.py b/tests/contrib/test_torch.py index cd0a8fcc..f8615f3a 100644 --- a/tests/contrib/test_torch.py +++ b/tests/contrib/test_torch.py @@ -3,13 +3,19 @@ import pytest import torch -from mlem.api import save +from mlem.api import import_object, save from mlem.constants import PREDICT_METHOD_NAME -from mlem.contrib.torch import TorchModelIO, TorchTensorReader +from mlem.contrib.torch import ( + TorchModel, + TorchModelImport, + TorchModelIO, + TorchTensorReader, +) from mlem.core.artifacts import LOCAL_STORAGE from mlem.core.data_type import DataAnalyzer, DataType from mlem.core.errors import DeserializationError, SerializationError from mlem.core.model import ModelAnalyzer +from mlem.core.objects import MlemModel from tests.conftest import data_write_read_check @@ -153,6 +159,21 @@ def check_model(net, input_data, tmpdir): save(net, str(tmpdir / "torch-net"), sample_data=input_data) +@pytest.mark.parametrize( + "net,torchsave", + [ + (torch.nn.Linear(5, 1), torch.save), + (torch.jit.script(torch.nn.Linear(5, 1)), torch.jit.save), + ], +) +def test_torch_import(tmp_path, net, torchsave): + path = tmp_path / "model" + torchsave(net, path) + meta = import_object(str(path), type_=TorchModelImport.type) + assert isinstance(meta, MlemModel) + assert isinstance(meta.model_type, TorchModel) + + # Copyright 2019 Zyfra # Copyright 2021 Iterative #