Skip to content

Commit

Permalink
Remove manual pickling logic from Features
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Feb 21, 2022
1 parent 17831c6 commit e33325c
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 55 deletions.
6 changes: 0 additions & 6 deletions aesara/graph/destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,6 @@ class DestroyHandler(Bookkeeper): # noqa
"""

pickle_rm_attr = ["destroyers", "has_destroyers"]

def __init__(self, do_imports_on_attach=True, algo=None):
self.do_imports_on_attach = do_imports_on_attach

Expand Down Expand Up @@ -319,9 +317,6 @@ def on_attach(self, fgraph):

fgraph.destroy_handler = self

# Annotate the FunctionGraph #
self.unpickle(fgraph)

fgraph.fail_validate = {}
"""
Maps every variable in the graph to its "foundation" (deepest
Expand Down Expand Up @@ -365,7 +360,6 @@ def on_attach(self, fgraph):
if self.do_imports_on_attach:
super().on_attach(fgraph)

def unpickle(self, fgraph):
def get_destroyers_of(fgraph, r):
droot, _, root_destroyer = self.refresh_droot_impact(fgraph)
try:
Expand Down
39 changes: 2 additions & 37 deletions aesara/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def on_attach(self, fgraph):
run checks on the initial contents of the FunctionGraph.
"""
for node in io_toposort(fgraph.inputs, fgraph.outputs):
self.on_import(fgraph, node, "on_attach")
self.on_import(fgraph, node, "Bookkeeper.on_attach")

def on_detach(self, fgraph):
"""
Expand Down Expand Up @@ -368,8 +368,6 @@ class History(Feature):
"""

pickle_rm_attr = ["checkpoint", "revert"]

def on_attach(self, fgraph):
if hasattr(fgraph, "checkpoint") or hasattr(fgraph, "revert"):
raise AlreadyThere(
Expand All @@ -379,21 +377,10 @@ def on_attach(self, fgraph):
fgraph._history_is_reverting = False
fgraph._history_nb = 0
fgraph._history_history = []
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.checkpoint = types.MethodType(self.checkpoint, fgraph)
fgraph.revert = types.MethodType(self.revert, fgraph)

def unpickle(self, fgraph):
fgraph.checkpoint = types.MethodType(self.checkpoint, fgraph)
fgraph.revert = types.MethodType(self.revert, fgraph)

def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
del fgraph.checkpoint
del fgraph.revert
del fgraph._history_history
Expand Down Expand Up @@ -428,22 +415,13 @@ def revert(fgraph, checkpoint):


class Validator(Feature):
pickle_rm_attr = ["validate", "consistent"]

def on_attach(self, fgraph):
for attr in ("validate", "validate_time"):
if hasattr(fgraph, attr):
raise AlreadyThere(
"Validator feature is already present or in"
" conflict with another plugin."
)
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.validate = types.MethodType(self.validate_, fgraph)
fgraph.consistent = types.MethodType(self.consistent_, fgraph)

def unpickle(self, fgraph):
fgraph.validate = types.MethodType(self.validate_, fgraph)
fgraph.consistent = types.MethodType(self.consistent_, fgraph)

Expand Down Expand Up @@ -498,12 +476,6 @@ def consistent_(fgraph):


class ReplaceValidate(History, Validator):
pickle_rm_attr = (
["replace_validate", "replace_all_validate", "replace_all_validate_remove"]
+ History.pickle_rm_attr
+ Validator.pickle_rm_attr
)

def on_attach(self, fgraph):
for attr in (
"replace_validate",
Expand All @@ -517,13 +489,10 @@ def on_attach(self, fgraph):
)
fgraph._replace_nodes_removed = set()
fgraph._replace_validate_failed = False

History.on_attach(self, fgraph)
Validator.on_attach(self, fgraph)
self.unpickle(fgraph)

def unpickle(self, fgraph):
History.unpickle(self, fgraph)
Validator.unpickle(self, fgraph)
fgraph.replace_validate = types.MethodType(self.replace_validate, fgraph)
fgraph.replace_all_validate = types.MethodType(
self.replace_all_validate, fgraph
Expand All @@ -533,10 +502,6 @@ def unpickle(self, fgraph):
)

def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
History.on_detach(self, fgraph)
Validator.on_detach(self, fgraph)
del fgraph._replace_nodes_removed
Expand Down
23 changes: 11 additions & 12 deletions aesara/graph/fg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time
from collections import OrderedDict
from types import MethodType
from typing import Any, Dict, List, Optional, Tuple, Union

import aesara
Expand Down Expand Up @@ -794,20 +795,18 @@ def clone_get_equiv(
return e, equiv

def __getstate__(self):
# This is needed as some features introduce instance methods
# This is not picklable
d = self.__dict__.copy()
for feature in self._features:
for attr in getattr(feature, "pickle_rm_attr", []):
del d[attr]
# The class Updater take fct as parameter and they are lambda function, so unpicklable.

# execute_callbacks_times have reference to optimizer, and they can't
# be pickled as the decorators with parameters aren't pickable.
if "execute_callbacks_times" in d:
del d["execute_callbacks_times"]
# Remove methods that were attached by features
self_dict = {
k: v for k, v in self.__dict__.items() if not isinstance(v, MethodType)
}

return d
# `execute_callbacks_times` holds references to optimizers, so they
# can't be pickled
if "execute_callbacks_times" in self_dict:
del self_dict["execute_callbacks_times"]

return self_dict

def __setstate__(self, dct):
self.__dict__.update(dct)
Expand Down
15 changes: 15 additions & 0 deletions tests/graph/test_destroyhandler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from copy import copy

import pytest
Expand Down Expand Up @@ -470,3 +471,17 @@ def test_multiple_inplace():
OpSubOptimizer(multiple_in_place_1, multiple_in_place_0_1, fail).optimize(g)
assert g.consistent()
assert fail.failures == 1


def test_pickle():
x, y, z = inputs()
tv = transpose_view(x)
e = add_in_place(x, tv)
fg = create_fgraph([x, y], [e], False)
assert not fg.consistent()

fg_pkld = pickle.dumps(fg)
fg_unpkld = pickle.loads(fg_pkld)

assert any(isinstance(ft, DestroyHandler) for ft in fg_unpkld._features)
assert all(hasattr(fg, attr) for attr in ("_destroyhandler_destroyers",))
51 changes: 51 additions & 0 deletions tests/graph/test_features.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pickle

import pytest

from aesara.graph.basic import Apply, Variable
Expand Down Expand Up @@ -144,6 +146,32 @@ def validate(self, *args):
capres = capsys.readouterr()
assert "optimizer: validate failed on node Op1.0" in capres.out

def test_pickle(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
fg = FunctionGraph([var1, var2], [var3], clone=False)

rv_feature = ReplaceValidate()
fg.attach_feature(rv_feature)

fg_pkld = pickle.dumps(fg)
fg_unpkld = pickle.loads(fg_pkld)

assert ReplaceValidate in set(type(ft) for ft in fg_unpkld._features)
assert all(
hasattr(fg, attr)
for attr in (
"replace_validate",
"replace_all_validate",
"replace_all_validate_remove",
"checkpoint",
"revert",
"validate",
"consistent",
)
)


class TestHistory:
def test_basic(self):
Expand All @@ -169,3 +197,26 @@ def test_basic(self):

assert not fg._history_history
assert var3 in fg.variables

def test_pickle(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
fg = FunctionGraph([var1, var2], [var3], clone=False)

hf = History()
fg.attach_feature(hf)

fg_pkld = pickle.dumps(fg)
fg_unpkld = pickle.loads(fg_pkld)

assert any(isinstance(ft, History) for ft in fg_unpkld._features)
assert all(
hasattr(fg, attr)
for attr in (
"checkpoint",
"revert",
"validate",
"consistent",
)
)

0 comments on commit e33325c

Please sign in to comment.