Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

replace _ExternalRef logic with persistent_id #365

Merged
merged 2 commits into from
Aug 8, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 11 additions & 66 deletions mlem/contrib/callable.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import pickle
import posixpath
from collections import defaultdict
from importlib import import_module
from io import BytesIO
from pickle import _Unpickler # type: ignore
from typing import Any, Callable, ClassVar, Dict, Optional, Tuple
from uuid import uuid4

from dill import Pickler
from dill import Pickler, Unpickler

from mlem.core.artifacts import Artifacts, Storage
from mlem.core.hooks import LOW_PRIORITY_VALUE
Expand Down Expand Up @@ -146,38 +144,12 @@ def __init__(self, model, *args, **kwargs):
self.model = model
self.refs: Dict[str, Tuple[ModelIO, Any]] = {}

# we couldn't import hook and analyzer at top as it leads to circular import failure
known_types = set()
for hook in ModelAnalyzer.hooks:
if not isinstance(hook, CallableModelType) and hook.valid_types:
known_types.update(hook.valid_types)
self.known_types = tuple(known_types)

# pickle "hook" for overriding serialization of objects
def save(self, obj, save_persistent_id=True):
    """
    Pickle "hook": keep objects that have their own IO out of the pickle stream.

    The root model (`self.model`) and every object for which
    `_get_non_pickle_io` returns None are pickled the default way.
    Any other object is stored in `self.refs` together with its IO under
    a fresh uuid, and an `_ExternalRef` stub carrying that uuid is pickled
    in its place.
    :param obj: obj to save
    :param save_persistent_id: passed through unchanged to the parent `save`
    :return: whatever the parent `save` returns
    """
    # the root model must take the usual path, otherwise we would loop forever;
    # for it the IO lookup is not even attempted
    external_io = None if obj is self.model else self._get_non_pickle_io(obj)
    if external_io is None:
        return super().save(obj, save_persistent_id)

    # object with non-pickle serialization: memorize its IO so it can be
    # serialized aside later, and pickle only a stub that points to it
    ref_id = str(uuid4())
    self.refs[ref_id] = (external_io, obj)
    return super().save(_ExternalRef(ref_id), save_persistent_id)

def _get_non_pickle_io(self, obj):
"""
Checks if obj has non-Pickle IO and returns it
Expand All @@ -200,49 +172,22 @@ def _get_non_pickle_io(self, obj):
# non-model object
return None

def persistent_id(self, obj: Any) -> Any:
    """
    Pickle "hook": give objects that have non-Pickle IO a persistent id.

    :param obj: object the pickler is about to save
    :return: None when `_get_non_pickle_io` finds no IO for `obj` (the
        object is then pickled as usual); otherwise the uuid string under
        which `(io, obj)` is stored in `self.refs`
    """
    io = self._get_non_pickle_io(obj)
    if io is None:
        return None

    # the pickler asks for a persistent id every time it meets the object
    # (before consulting its memo), so hand out the id of an already
    # registered object instead of registering the same object twice
    for known_uuid, (_, known_obj) in self.refs.items():
        if known_obj is obj:
            return known_uuid

    obj_uuid = str(uuid4())
    self.refs[obj_uuid] = (io, obj)
    return obj_uuid

# `Unpickler`, unlike `_Unpickler`, doesn't support `load_build` overriding
class _ModelUnpickler(_Unpickler):
"""
A class to unpickle model saved with :class:`_ModelPickler`
:param refs: dict of object uuid -> ref_obj
:param args: pickle._Unpickler args
:param kwargs: pickle._Unpickler kwargs
"""

dispatch = _Unpickler.dispatch.copy()

class _ModelUnpickler(Unpickler):
def __init__(self, refs, *args, **kwargs):
    """
    :param refs: dict of object uuid -> ref_obj, kept as `self.refs`
    :param args: positional args passed on to the parent constructor
    :param kwargs: keyword args passed on to the parent constructor
    """
    super().__init__(*args, **kwargs)
    self.refs = refs

# pickle "hook" for overriding deserialization of objects
def load_build(self):
    """
    Runs the regular BUILD step, then checks whether the object it left
    on top of the stack is an :class:`_ExternalRef`; if so, swaps it for
    the object registered in `self.refs` under the stub's `ref`
    :return:
    """
    super().load_build()

    # the most recently built object sits on top of the unpickler stack
    top = self.stack[-1]
    if isinstance(top, _ExternalRef):
        # drop the stub and put the real model it references in its place
        self.stack.pop()
        self.stack.append(self.refs[top.ref])

dispatch[pickle.BUILD[0]] = load_build # type: ignore


class _ExternalRef:
    """
    A stub that marks an object dumped with its own :class:`ModelIO`

    :param ref: key under which the real object is registered in `refs`
    """

    def __init__(self, ref: str):
        self.ref = ref
def persistent_load(self, pid: str) -> Any:
    """
    Returns the object stored in `self.refs` under `pid`

    :param pid: key to look up in `self.refs`
    :return: `self.refs[pid]`; a failed lookup is not handled here
    """
    return self.refs[pid]


class CallableModelType(ModelType, ModelHook):
Expand Down