Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Fastai #268

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/check-test-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ jobs:
HEROKU_TEAM: iterative-sandbox
GITHUB_MATRIX_OS: ${{ matrix.os }}
GITHUB_MATRIX_PYTHON: ${{ matrix.python }}
BITBUCKET_USERNAME: ${{ secrets.BITBUCKET_USERNAME }}
BITBUCKET_PASSWORD: ${{ secrets.BITBUCKET_PASSWORD }}
- name: "Upload coverage to Codecov"
uses: codecov/codecov-action@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ ignore-comments=yes
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no
ignore-imports=yes

# Ignore function signatures when computing similarities.
ignore-signatures=no
Expand Down
27 changes: 11 additions & 16 deletions mlem/contrib/catboost.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import os
import posixpath
import tempfile
from enum import Enum
from typing import Any, ClassVar, Optional

import catboost
from catboost import CatBoost, CatBoostClassifier, CatBoostRegressor

from mlem.core.artifacts import Artifacts, Storage
from mlem.core.artifacts import Artifacts
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.model import (
BufferModelIO,
ModelHook,
ModelIO,
ModelType,
Signature,
)
from mlem.core.requirements import InstallableRequirement, Requirements


Expand All @@ -18,7 +21,7 @@ class CBType(str, Enum):
regressor = "reg"


class CatBoostModelIO(ModelIO):
class CatBoostModelIO(BufferModelIO):
"""
:class:`mlem.core.model.ModelIO` for CatBoost models.
"""
Expand All @@ -28,16 +31,8 @@ class CatBoostModelIO(ModelIO):
regressor_file_name: ClassVar = "rgr.cb"
model_type: CBType = CBType.regressor

def dump(self, storage: Storage, path, model) -> Artifacts:
with tempfile.TemporaryDirectory() as tmpdir:
model_name = self._get_model_file_name(model)
model_path = os.path.join(tmpdir, model_name)
model.save_model(model_path)
return {
self.art_name: storage.upload(
model_path, posixpath.join(path, model_name)
)
}
def save_model(self, model: Any, path: str):
model.save_model(path)

def load(self, artifacts: Artifacts):
"""
Expand Down
101 changes: 101 additions & 0 deletions mlem/contrib/fastai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import Any, ClassVar, Optional, Type, Union

from fastai.data.transforms import Category
from fastai.learner import Learner, load_learner
from fastai.vision.core import PILImage
from pydantic import BaseModel

from mlem.core.artifacts import Artifacts
from mlem.core.data_type import DataHook, DataSerializer, DataType, DataWriter
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import BufferModelIO, ModelHook, ModelType, Signature
from mlem.core.requirements import Requirements


class FastAIModelIO(BufferModelIO):
    """
    :class:`mlem.core.model.BufferModelIO` implementation for fastai models.

    Saving delegates to the model's own ``export`` method; loading reads the
    stored artifact back with ``fastai.learner.load_learner``.
    """

    type: ClassVar = "fastai"

    def save_model(self, model: Any, path: str):
        """Export ``model`` to ``path`` by calling ``model.export``."""
        model.export(path)

    def load(self, artifacts: Artifacts):
        """Open the artifact stored under ``self.art_name`` and load it."""
        model_artifact = artifacts[self.art_name]
        with model_artifact.open() as stream:
            # load_learner is given the opened stream, not a filesystem path
            return load_learner(stream)


class FastAIModel(ModelType, ModelHook, IsInstanceHookMixin):
    """
    :class:`mlem.core.model.ModelType` implementation for fastai models.

    ``valid_types`` lists ``Learner`` as the object type this class is
    declared for; model files are handled by :class:`FastAIModelIO`.
    """

    type: ClassVar = "fastai"
    valid_types: ClassVar = (Learner,)
    io: FastAIModelIO = FastAIModelIO()

    @classmethod
    def process(
        cls, obj: "Learner", sample_data: Optional[Any] = None, **kwargs
    ) -> ModelType:
        """
        Build a :class:`FastAIModel` whose only method is ``predict``.

        The ``predict`` signature is derived from ``obj.predict``; automatic
        inference is requested only when ``sample_data`` is given.
        """
        has_sample = sample_data is not None
        predict_signature = Signature.from_method(
            obj.predict,
            item=sample_data,
            auto_infer=has_sample,
        )
        # NOTE(review): ``obj`` itself is not attached to the returned
        # instance here - confirm whether the caller binds it afterwards.
        return FastAIModel(methods={"predict": predict_signature})


class CategoryDataType(
    DataType, DataSerializer, DataHook, IsInstanceHookMixin
):
    """
    Data type for fastai ``Category`` objects.

    Only :meth:`process` and :meth:`get_requirements` are implemented;
    serialization, model generation and writer creation currently raise
    :class:`NotImplementedError`.
    """

    type: ClassVar = "fastai_category"
    valid_types: ClassVar = (Category,)
    value: str  # ``str(obj)`` of the processed object, see ``process``

    def serialize(self, instance: Any) -> dict:
        """Not implemented yet."""
        raise NotImplementedError  # TODO

    def deserialize(self, obj: dict) -> Any:
        """Not implemented yet."""
        raise NotImplementedError  # TODO

    def get_model(self, prefix: str = "") -> Union[Type[BaseModel], type]:
        """Not implemented yet."""
        raise NotImplementedError  # TODO

    def get_requirements(self) -> Requirements:
        """Return requirements consisting of the ``fastai`` package."""
        return Requirements.new("fastai")

    @classmethod
    def process(cls, obj: Any, **kwargs):
        """Create the data type from ``obj``, keeping its string form."""
        category_repr = str(obj)
        return CategoryDataType(value=category_repr)

    def get_writer(
        self, project: str = None, filename: str = None, **kwargs
    ) -> DataWriter:
        """Not implemented yet."""
        raise NotImplementedError  # TODO


class PILImageDataType(
    DataType, DataSerializer, DataHook, IsInstanceHookMixin
):
    """
    Data type for fastai ``PILImage`` objects.

    Only :meth:`process` and :meth:`get_requirements` are implemented;
    serialization, model generation and writer creation currently raise
    :class:`NotImplementedError`.
    """

    type: ClassVar = "fastai_pil_image"
    valid_types: ClassVar = (PILImage,)

    def serialize(self, instance: Any) -> dict:
        """Not implemented yet."""
        raise NotImplementedError  # TODO

    def deserialize(self, obj: dict) -> Any:
        """Not implemented yet."""
        raise NotImplementedError  # TODO

    def get_model(self, prefix: str = "") -> Union[Type[BaseModel], type]:
        """Not implemented yet."""
        raise NotImplementedError  # TODO

    def get_requirements(self) -> Requirements:
        """Return requirements consisting of the ``fastai`` package."""
        return Requirements.new("fastai")

    @classmethod
    def process(cls, obj: Any, **kwargs):
        """Create the data type; nothing from ``obj`` is stored."""
        image_type = PILImageDataType()
        return image_type

    def get_writer(
        self, project: str = None, filename: str = None, **kwargs
    ) -> DataWriter:
        """Not implemented yet."""
        raise NotImplementedError  # TODO
19 changes: 10 additions & 9 deletions mlem/contrib/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import posixpath
import tempfile
from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type

Expand All @@ -18,7 +17,13 @@
)
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.model import (
BufferModelIO,
ModelHook,
ModelIO,
ModelType,
Signature,
)
from mlem.core.requirements import (
AddRequirementHook,
InstallableRequirement,
Expand Down Expand Up @@ -122,20 +127,16 @@ def read_batch(
raise NotImplementedError


class LightGBMModelIO(ModelIO):
class LightGBMModelIO(BufferModelIO):
"""
:class:`.ModelIO` implementation for `lightgbm.Booster` type
"""

type: ClassVar[str] = "lightgbm_io"
model_file_name = "model.lgb"

def dump(self, storage: Storage, path, model) -> Artifacts:
with tempfile.TemporaryDirectory(prefix="mlem_lightgbm_dump") as f:
model_path = os.path.join(f, self.model_file_name)
model.save_model(model_path)
fs_path = posixpath.join(path, self.model_file_name)
return {self.art_name: storage.upload(model_path, fs_path)}
def save_model(self, model: Any, path: str):
model.save_model(path)

def load(self, artifacts: Artifacts):
if len(artifacts) != 1:
Expand Down
23 changes: 11 additions & 12 deletions mlem/contrib/xgboost.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import posixpath
import tempfile
from typing import Any, ClassVar, Dict, List, Optional, Type

Expand All @@ -8,11 +7,17 @@

from mlem.constants import PREDICT_METHOD_NAME
from mlem.contrib.numpy import python_type_from_np_string_repr
from mlem.core.artifacts import Artifacts, Storage
from mlem.core.artifacts import Artifacts
from mlem.core.data_type import DataHook, DataSerializer, DataType, DataWriter
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.model import (
BufferModelIO,
ModelHook,
ModelIO,
ModelType,
Signature,
)
from mlem.core.requirements import (
AddRequirementHook,
InstallableRequirement,
Expand Down Expand Up @@ -112,22 +117,16 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]:
raise NotImplementedError


class XGBoostModelIO(ModelIO):
class XGBoostModelIO(BufferModelIO):
"""
:class:`~.ModelIO` implementation for XGBoost models
"""

type: ClassVar[str] = "xgboost_io"
model_file_name = "model.xgb"

def dump(
self, storage: Storage, path, model: xgboost.Booster
) -> Artifacts:
with tempfile.TemporaryDirectory(prefix="mlem_xgboost_dump") as f:
local_path = os.path.join(f, self.model_file_name)
model.save_model(local_path)
remote_path = posixpath.join(path, self.model_file_name)
return {self.art_name: storage.upload(local_path, remote_path)}
def save_model(self, model: Any, path: str):
model.save_model(path)

def load(self, artifacts: Artifacts):
if len(artifacts) != 1:
Expand Down
14 changes: 14 additions & 0 deletions mlem/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
Base classes to work with ML models in MLEM
"""
import inspect
import os
import pickle
import tempfile
from abc import ABC, abstractmethod
from typing import (
Any,
Expand Down Expand Up @@ -50,6 +52,18 @@ def load(self, artifacts: Artifacts):
raise NotImplementedError


class BufferModelIO(ModelIO, ABC):
    """
    Base :class:`ModelIO` for models that are first written to a local file.

    Subclasses implement :meth:`save_model`; :meth:`dump` calls it on a file
    inside a temporary directory and uploads that file to the storage.
    """

    @abstractmethod
    def save_model(self, model: Any, path: str):
        """Write ``model`` to the local file ``path``."""
        raise NotImplementedError

    def dump(self, storage: Storage, path, model) -> Artifacts:
        """
        Save ``model`` locally and upload the result to ``storage`` at ``path``.

        Returns a mapping from ``self.art_name`` to whatever
        ``storage.upload`` returns.
        """
        with tempfile.TemporaryDirectory() as scratch_dir:
            local_file = os.path.join(scratch_dir, "model")
            self.save_model(model, local_file)
            artifact_key = self.art_name
            # upload before leaving the block: the directory is removed on exit
            uploaded = storage.upload(local_file, path)
            return {artifact_key: uploaded}


class SimplePickleIO(ModelIO):
"""IO with simple pickling of python model object"""

Expand Down
1 change: 1 addition & 0 deletions mlem/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class ExtensionLoader:
Extension("mlem.contrib.fastapi", ["fastapi", "uvicorn"], False),
Extension("mlem.contrib.callable", [], True),
Extension("mlem.contrib.rabbitmq", ["pika"], False, extra="rmq"),
Extension("mlem.contrib.fastai", ["fastai"], False),
)

_loaded_extensions: Dict[Extension, ModuleType] = {}
Expand Down
4 changes: 3 additions & 1 deletion mlem/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,9 @@ def add_requirement(self, obj_or_module):
)
if parent_package_name not in self._modules:
parent_package = sys.modules[parent_package_name]
self.add_requirement(parent_package)
# exclude namespace packages
if parent_package.__file__ is not None:
self.add_requirement(parent_package)

def save(self, obj, save_persistent_id=True):
if id(obj) in self.seen or isinstance(obj, IGNORE_TYPES_REQ):
Expand Down
5 changes: 5 additions & 0 deletions tests/contrib/test_fastai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# TODO: add real tests for mlem.contrib.fastai


def test_learner():
    """Placeholder test: checks nothing yet and always passes."""
    return None