Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Expose GTO version in FastAPI's interface #681

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion mlem/contrib/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
get_local_file_info,
)
from mlem.core.meta_io import get_fs
from mlem.core.registry import ArtifactInRegistry

BATCH_SIZE = 10**5

Expand All @@ -32,7 +33,7 @@ def find_dvc_repo_root(path: str):
while True:
if os.path.isdir(os.path.join(_path, ".dvc")):
return _path
if _path == "/":
if _path == "/" or not _path:
break
_path = os.path.dirname(_path)
raise NotDvcRepoError(f"Path {path} is not in dvc repo")
Expand Down Expand Up @@ -118,3 +119,43 @@ def open(self) -> Iterator[IO]:
def relative(self, fs: AbstractFileSystem, path: str) -> "DVCArtifact":
relative = super().relative(fs, path)
return DVCArtifact(uri=relative.uri, size=self.size, hash=self.hash)


def find_artifact_name_by_path(path):
    """Look up the DVC artifact name declared for the file at ``path``.

    Args:
        path: Location of the artifact file or of its ``.mlem`` metafile.
            A relative path is resolved against the current working
            directory before the lookup.

    Returns:
        A ``(root, name)`` tuple. ``root`` is the DVC repository root that
        contains ``path``; ``name`` is the name of the artifact whose
        ``path`` equals the repo-relative path, or ``None`` when no declared
        artifact matches.

    Raises:
        NotDvcRepoError: raised by ``find_dvc_repo_root`` when ``path`` is
            not inside a DVC repository.
    """
    # A guard of the form `if not os.path.abspath(path)` can never fire:
    # abspath() returns a non-empty string for every input. Normalising the
    # path gives the "absolute path" guarantee without rejecting callers
    # that pass a relative location.
    path = os.path.abspath(path)

    from dvc.repo import Repo

    root = find_dvc_repo_root(path)
    relpath = os.path.relpath(path, root)

    # Artifacts are matched by the path without the metafile suffix.
    # NOTE(review): `relpath.rstrip(".mlem")` was suggested in review, but
    # rstrip() removes any trailing run of the characters ".", "m", "l", "e"
    # (e.g. "model.mlem" -> "mod"), so an exact suffix cut is used instead.
    suffix = ".mlem"
    if relpath.endswith(suffix):
        relpath = relpath[: -len(suffix)]

    repo = Repo(root)
    # Only the artifact mappings are needed; which dvc.yaml declared them
    # does not matter for the lookup.
    for artifacts in repo.artifacts.read().values():
        for name, value in artifacts.items():
            if value.path == relpath:
                return root, name
    return root, None


def find_version(root, name):
    """Return the version GTO reports for artifact ``name`` at ``HEAD``.

    Args:
        root: Repository location passed through to GTO.
        name: Name of the artifact to look up.

    Returns:
        The ``"version"`` entry of the first record returned by
        ``_show_versions``, or ``None`` when it returns nothing.
    """
    # NOTE(review): `_show_versions` is a private GTO helper; its result is
    # presumably a list of dict-like records — confirm against GTO.
    from gto.api import _show_versions

    records = _show_versions(root, name, ref="HEAD")
    if not records:
        return None
    return records[0]["version"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make this more flexible and able to support other version metadata, I suggest that `version` return a JSON object rather than a string:

  • GTO version implementation will now return {"semver": <VERSION>}
  • obvious next steps in the future would be to add git version information (branch, git rev).

This will also make it possible to return some version information even when there is no GTO tracking, and to support other kinds of versions in the future, such as tags that people create manually with Git rather than via GTO.



class DVCArtifactInRegistry(ArtifactInRegistry):
    """Registry artifact whose name is resolved from DVC artifact
    declarations and whose version is looked up by that name."""

    type: ClassVar = "dvc"
    uri: str
    """Local path to file"""

    @property
    def version(self):
        """Version of the artifact located at ``uri``.

        Returns ``None`` when no artifact name is found for ``uri``;
        otherwise returns whatever ``find_version`` yields for that name.
        """
        repo_root, artifact_name = find_artifact_name_by_path(self.uri)
        if not artifact_name:
            return None
        return find_version(repo_root, artifact_name)
5 changes: 5 additions & 0 deletions mlem/contrib/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ def process(
methods_sample_data: Optional[Dict[str, Any]] = None,
**kwargs
) -> ModelType:
if not hasattr(obj, "predict"):
# assuming the Pipeline has only transform steps
return SklearnTransformer.process(
obj, sample_data, methods_sample_data, **kwargs
)
methods_sample_data = methods_sample_data or {}
mt = SklearnPipelineType(io=SimplePickleIO(), methods={}).bind(obj)
predict = obj.predict
Expand Down
26 changes: 26 additions & 0 deletions mlem/core/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from abc import ABC
from typing import ClassVar

from mlem.core.base import MlemABC


class ArtifactInRegistry(MlemABC, ABC):
    """Base class for an artifact registered within an Artifact Registry.

    Subclasses can provide the artifact's version and the stage it has been
    promoted to. Both properties raise ``NotImplementedError`` in this base
    class.
    """

    class Config:
        type_root = True
        # NOTE(review): "gto" is declared as the default type, but no
        # implementation with that type name is visible here — confirm
        # that one exists.
        default_type = "gto"

    abs_name: ClassVar = "artifact_in_registry"
    uri: str
    """location"""

    @property
    def version(self):
        """Version of the artifact; not implemented in the base class."""
        raise NotImplementedError

    @property
    def stage(self):
        """Stage the artifact is promoted to; not implemented in the base class."""
        raise NotImplementedError
27 changes: 26 additions & 1 deletion mlem/runtime/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import logging
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Iterator, List, Optional, Tuple

Expand All @@ -18,6 +19,8 @@
from mlem.core.model import Argument, Signature
from mlem.core.objects import MlemModel

logger = logging.getLogger(__name__)


class ExecutionError(MlemError):
pass
Expand Down Expand Up @@ -100,6 +103,9 @@ class VersionedInterfaceDescriptor(BaseModel):
version: str = mlem.version.__version__
"""mlem version"""
meta: Any
"""model params"""
model_version: Optional[str] = None
"""model version"""


class Interface(ABC, MlemABC):
Expand All @@ -112,6 +118,7 @@ class Config:
type_root = True

abs_name: ClassVar[str] = "interface"
model_version: Optional[str] = None

@abstractmethod
def get_method_executor(self, method_name: str):
Expand Down Expand Up @@ -205,11 +212,15 @@ def get_descriptor(self) -> InterfaceDescriptor:
def get_model_meta(self):
    """Return model metadata for the descriptor; this implementation has none."""
    return None

def get_model_version(self):
    """Return the model version for the descriptor; this implementation has none."""
    return None

def get_versioned_descriptor(self) -> VersionedInterfaceDescriptor:
    """Bundle the method descriptor with mlem and model version information.

    Returns:
        A ``VersionedInterfaceDescriptor`` filled with the mlem version,
        the result of ``get_descriptor()``, the model metadata and the
        model version.
    """
    # Values are gathered in the same order the constructor arguments
    # were originally evaluated.
    fields = {
        "version": mlem.__version__,
        "methods": self.get_descriptor(),
        "meta": self.get_model_meta(),
        "model_version": self.get_model_version(),
    }
    return VersionedInterfaceDescriptor(**fields)


Expand Down Expand Up @@ -299,7 +310,18 @@ def ensure_signature(cls, value: MlemModel):

@classmethod
def from_model(cls, model: MlemModel):
    """Build an interface for ``model``, attaching its registry version.

    The version lookup is best-effort: any failure while resolving it
    (including a failure to import the DVC integration) is logged and the
    interface is created with ``model_version=None``.

    Args:
        model: The model to wrap.

    Returns:
        An instance of ``cls`` wrapping ``model``.
    """
    model_version = None
    try:
        # Imported inside the guard so that a failing import of the DVC
        # integration degrades to "no version" instead of preventing the
        # interface from being built at all.
        import os

        from mlem.contrib.dvc import DVCArtifactInRegistry

        model_version = DVCArtifactInRegistry(
            uri=os.path.join(model.loc.project or "", model.loc.path)
        ).version
    except Exception:  # pylint: disable=broad-except
        # Deliberately broad: version information is optional metadata.
        logger.info("Cannot get model version")
        # Keep the reason available for troubleshooting without adding
        # a traceback to normal output.
        logger.debug("Model version lookup failed", exc_info=True)
    return cls(model=model, model_version=model_version)

@property
def is_single_model(self):
Expand Down Expand Up @@ -361,3 +383,6 @@ def get_method_args(

def get_model_meta(self):
    """Return the ``params`` of the wrapped model."""
    wrapped = self.model
    return wrapped.params

def get_model_version(self):
    """Return the model version stored on this interface."""
    return self.model_version
3 changes: 3 additions & 0 deletions mlem/runtime/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,6 @@ def get_method_signature(self, method_name: str) -> InterfaceMethod:

def get_model_meta(self):
    """Return the ``params`` of the wrapped interface's model.

    Yields ``None`` when the interface has no ``model`` attribute or the
    model has no ``params`` attribute.
    """
    wrapped_model = getattr(self.interface, "model", None)
    return getattr(wrapped_model, "params", None)

def get_model_version(self):
    """Return the model version reported by the wrapped interface."""
    wrapped = self.interface
    return wrapped.model_version
6 changes: 6 additions & 0 deletions mlem/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,12 @@ def _should_ignore(self, mod: ModuleType):
)

def add_requirement(self, obj_or_module):
# if callable(obj_or_module):
# from sklearn.preprocessing import FunctionTransformer
# from sklearn.pipeline import Pipeline
# if isinstance(obj_or_module, Pipeline):
# import ipdb; ipdb.set_trace()

if not isinstance(obj_or_module, ModuleType):
try:
module = get_object_module(obj_or_module)
Expand Down