diff --git a/mlem/contrib/dvc.py b/mlem/contrib/dvc.py
index e19192be..81d9eb9a 100644
--- a/mlem/contrib/dvc.py
+++ b/mlem/contrib/dvc.py
@@ -21,6 +21,7 @@
     get_local_file_info,
 )
 from mlem.core.meta_io import get_fs
+from mlem.core.registry import ArtifactInRegistry
 
 BATCH_SIZE = 10**5
 
@@ -32,7 +33,7 @@ def find_dvc_repo_root(path: str):
     while True:
         if os.path.isdir(os.path.join(_path, ".dvc")):
             return _path
-        if _path == "/":
+        if _path == "/" or not _path:
             break
         _path = os.path.dirname(_path)
     raise NotDvcRepoError(f"Path {path} is not in dvc repo")
@@ -118,3 +119,63 @@ def open(self) -> Iterator[IO]:
     def relative(self, fs: AbstractFileSystem, path: str) -> "DVCArtifact":
         relative = super().relative(fs, path)
         return DVCArtifact(uri=relative.uri, size=self.size, hash=self.hash)
+
+
+def find_artifact_name_by_path(path):
+    """Find the name a path is declared under in the DVC artifacts.
+
+    Args:
+        path: absolute local path; a trailing ".mlem" is ignored
+
+    Returns:
+        (root, name) - DVC repo root and the artifact name, or
+        (root, None) if no declared artifact has this path
+
+    Raises:
+        ValueError: if path is not absolute
+        NotDvcRepoError: if path is not inside a DVC repo
+    """
+    # os.path.abspath() never returns a falsy value, so it cannot act as a
+    # check; os.path.isabs() is the actual predicate
+    if not os.path.isabs(path):
+        raise ValueError(f"Path {path} is not absolute")
+
+    from dvc.repo import Repo
+
+    root = find_dvc_repo_root(path)
+    relpath = os.path.relpath(path, root)
+    if relpath.endswith(".mlem"):
+        relpath = relpath[:-5]
+    repo = Repo(root)
+    for _dvcyaml, artifacts in repo.artifacts.read().items():
+        for name, value in artifacts.items():
+            if value.path == relpath:
+                return root, name
+    return root, None
+
+
+def find_version(root, name):
+    """Return the "version" of the first entry gto lists for artifact
+    `name` at HEAD of the repo at `root`, or None if it lists nothing"""
+    from gto.api import _show_versions
+
+    version = _show_versions(root, name, ref="HEAD")
+    if version:
+        return version[0]["version"]
+    return None
+
+
+class DVCArtifactInRegistry(ArtifactInRegistry):
+    """Artifact registered within an Artifact Registry."""
+
+    type: ClassVar = "dvc"
+    uri: str
+    """Local path to file"""
+
+    @property
+    def version(self):
+        """Version found for `uri`, None if it is not a DVC artifact"""
+        root, name = find_artifact_name_by_path(self.uri)
+        if name:
+            return find_version(root, name)
+        return None
diff --git a/mlem/contrib/sklearn.py b/mlem/contrib/sklearn.py
index a027174e..effe9b8e 100644
--- a/mlem/contrib/sklearn.py
+++ b/mlem/contrib/sklearn.py
@@ -103,6 +103,11 @@ def process(
         methods_sample_data: Optional[Dict[str, Any]] = None,
         **kwargs
     ) -> ModelType:
+        if not hasattr(obj, "predict"):
+            # assuming the Pipeline has only transform steps
+            return SklearnTransformer.process(
+                obj, sample_data, methods_sample_data, **kwargs
+            )
         methods_sample_data = methods_sample_data or {}
         mt = SklearnPipelineType(io=SimplePickleIO(), methods={}).bind(obj)
         predict = obj.predict
diff --git a/mlem/core/registry.py b/mlem/core/registry.py
new file mode 100644
index 00000000..ff455639
--- /dev/null
+++ b/mlem/core/registry.py
@@ -0,0 +1,26 @@
+from abc import ABC
+from typing import ClassVar
+
+from mlem.core.base import MlemABC
+
+
+class ArtifactInRegistry(MlemABC, ABC):
+    """Artifact registered within an Artifact Registry.
+    Can provide version and stage it's promoted to.
+    """
+
+    class Config:
+        type_root = True
+        default_type = "gto"
+
+    abs_name: ClassVar = "artifact_in_registry"
+    uri: str
+    """location"""
+
+    @property
+    def version(self):
+        raise NotImplementedError
+
+    @property
+    def stage(self):
+        raise NotImplementedError
diff --git a/mlem/runtime/interface.py b/mlem/runtime/interface.py
index 1c47bcae..6e1e8452 100644
--- a/mlem/runtime/interface.py
+++ b/mlem/runtime/interface.py
@@ -1,4 +1,5 @@
 import inspect
+import logging
 from abc import ABC, abstractmethod
 from typing import Any, ClassVar, Dict, Iterator, List, Optional, Tuple
 
@@ -18,6 +19,8 @@
 from mlem.core.model import Argument, Signature
 from mlem.core.objects import MlemModel
 
+logger = logging.getLogger(__name__)
+
 
 class ExecutionError(MlemError):
     pass
@@ -100,6 +103,9 @@ class VersionedInterfaceDescriptor(BaseModel):
     version: str = mlem.version.__version__
     """mlem version"""
     meta: Any
+    """model params"""
+    model_version: Optional[str] = None
+    """model version"""
 
 
 class Interface(ABC, MlemABC):
@@ -112,6 +118,7 @@ class Config:
         type_root = True
 
     abs_name: ClassVar[str] = "interface"
+    model_version: Optional[str] = None
 
     @abstractmethod
     def get_method_executor(self, method_name: str):
@@ -205,11 +212,15 @@ def get_descriptor(self) -> InterfaceDescriptor:
     def get_model_meta(self):
         return None
 
+    def get_model_version(self):
+        return None
+
     def get_versioned_descriptor(self) -> VersionedInterfaceDescriptor:
         return VersionedInterfaceDescriptor(
             version=mlem.__version__,
             methods=self.get_descriptor(),
             meta=self.get_model_meta(),
+            model_version=self.get_model_version(),
         )
 
 
@@ -299,7 +310,18 @@ def ensure_signature(cls, value: MlemModel):
 
     @classmethod
     def from_model(cls, model: MlemModel):
-        return cls(model=model)
+        import os
+
+        from mlem.contrib.dvc import DVCArtifactInRegistry
+
+        try:
+            model_version = DVCArtifactInRegistry(
+                uri=os.path.join(model.loc.project or "", model.loc.path)
+            ).version
+        except Exception:
+            logger.info("Cannot get model version")
+            model_version = None
+        return cls(model=model, model_version=model_version)
 
     @property
     def is_single_model(self):
@@ -361,3 +383,6 @@ def get_method_args(
 
     def get_model_meta(self):
         return self.model.params
+
+    def get_model_version(self):
+        return self.model_version
diff --git a/mlem/runtime/server.py b/mlem/runtime/server.py
index 97f1265e..788dc5e1 100644
--- a/mlem/runtime/server.py
+++ b/mlem/runtime/server.py
@@ -338,3 +338,6 @@ def get_method_signature(self, method_name: str) -> InterfaceMethod:
 
     def get_model_meta(self):
         return getattr(getattr(self.interface, "model", None), "params", None)
+
+    def get_model_version(self):
+        return self.interface.model_version