Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
replace _ExternalRef logic with persistent_id (#365)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Guschin <[email protected]>
  • Loading branch information
mike0sv and aguschin authored Aug 8, 2022
1 parent db8461d commit ee994a0
Showing 1 changed file with 11 additions and 66 deletions.
77 changes: 11 additions & 66 deletions mlem/contrib/callable.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import pickle
import posixpath
from collections import defaultdict
from importlib import import_module
from io import BytesIO
from pickle import _Unpickler # type: ignore
from typing import Any, Callable, ClassVar, Dict, Optional, Tuple
from uuid import uuid4

from dill import Pickler
from dill import Pickler, Unpickler

from mlem.core.artifacts import Artifacts, Storage
from mlem.core.hooks import LOW_PRIORITY_VALUE
Expand Down Expand Up @@ -146,38 +144,12 @@ def __init__(self, model, *args, **kwargs):
self.model = model
self.refs: Dict[str, Tuple[ModelIO, Any]] = {}

# we couldn't import hook and analyzer at top as it leads to circular import failure
known_types = set()
for hook in ModelAnalyzer.hooks:
if not isinstance(hook, CallableModelType) and hook.valid_types:
known_types.update(hook.valid_types)
self.known_types = tuple(known_types)

# pickle "hook" for overriding serialization of objects
def save(self, obj, save_persistent_id=True):
    """
    Serialization hook: swaps objects that have their own IO for a stub.

    The root model (`self.model`) and every object for which
    `_get_non_pickle_io` returns `None` go through the regular pickling path.
    Any other object is remembered in `self.refs` together with its IO
    under a freshly generated uuid, and an `_ExternalRef` carrying that uuid
    is pickled in its place.

    :param obj: obj to save
    :param save_persistent_id: passed through unchanged to the parent `save`
    :return: whatever the parent `save` returns
    """
    # the root model is never looked up: it must follow the usual path,
    # otherwise we would fall into an infinite loop
    external_io = None if obj is self.model else self._get_non_pickle_io(obj)
    if external_io is None:
        return super().save(obj, save_persistent_id)

    # memorize the IO so the object can be serialized aside later
    ref_key = str(uuid4())
    self.refs[ref_key] = (external_io, obj)
    stub = _ExternalRef(ref_key)
    return super().save(stub, save_persistent_id)

def _get_non_pickle_io(self, obj):
"""
Checks if obj has non-Pickle IO and returns it
Expand All @@ -200,49 +172,22 @@ def _get_non_pickle_io(self, obj):
# non-model object
return None

def persistent_id(self, obj: Any) -> Any:
    """
    Pickle hook deciding whether `obj` is stored outside of the pickle stream.

    Objects for which `_get_non_pickle_io` returns `None` are pickled as usual.
    For any other object the `(io, obj)` pair is kept in `self.refs` under a
    generated uuid, and that uuid is emitted as the persistent id.

    Pickle consults this hook for every occurrence of an object and does not
    memoize objects that receive a persistent id. An object that is
    referenced several times therefore reuses its existing key instead of
    being registered (and later dumped) once per reference.

    :param obj: object about to be pickled
    :return: `None` to pickle `obj` as usual, otherwise its key in `self.refs`
    """
    io = self._get_non_pickle_io(obj)
    if io is None:
        return None

    # identity comparison on purpose: equal-but-distinct objects stay distinct
    for known_uuid, (_, known_obj) in self.refs.items():
        if known_obj is obj:
            return known_uuid

    obj_uuid = str(uuid4())
    self.refs[obj_uuid] = (io, obj)
    return obj_uuid

# `Unpickler`, unlike `_Unpickler`, doesn't support `load_build` overriding
class _ModelUnpickler(_Unpickler):
"""
A class to unpickle model saved with :class:`_ModelPickler`
:param refs: dict of object uuid -> ref_obj
:param args: pickle._Unpickler args
:param kwargs: pickle._Unpickler kwargs
"""

dispatch = _Unpickler.dispatch.copy()

class _ModelUnpickler(Unpickler):
def __init__(self, refs, *args, **kwargs):
    """
    :param refs: dict of object uuid -> ref_obj
    :param args: positional args forwarded to the parent unpickler
    :param kwargs: keyword args forwarded to the parent unpickler
    """
    super().__init__(*args, **kwargs)
    self.refs = refs

# pickle "hook" for overriding deserialization of objects
def load_build(self):
    """
    Deserialization hook run after the parent `load_build`.

    If the object now on top of the stack is an :class:`_ExternalRef`,
    it is swapped for the entry of `self.refs` stored under its `ref`;
    any other object is left untouched.

    :return: None
    """
    super().load_build()

    # top of the stack is the most recently built object
    top = self.stack[-1]
    if isinstance(top, _ExternalRef):
        # drop the stub and push the real object it references
        self.stack.pop()
        self.stack.append(self.refs[top.ref])

dispatch[pickle.BUILD[0]] = load_build # type: ignore


class _ExternalRef:
    """
    A stub marking an object that was dumped with its own :class:`ModelIO`

    :param ref: key of the referenced object, kept in the `ref` attribute
    """

    def __init__(self, ref: str):
        self.ref: str = ref
def persistent_load(self, pid: str) -> Any:
    """
    Resolves a persistent id read from the pickle stream.

    :param pid: key to look up in `self.refs`
    :return: the value stored in `self.refs` under `pid`
    """
    resolved = self.refs[pid]
    return resolved


class CallableModelType(ModelType, ModelHook):
Expand Down

0 comments on commit ee994a0

Please sign in to comment.