Skip to content

Commit

Permalink
Make _VariableEquivalenceTracker a proper Feature and less stateful
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Feb 21, 2022
1 parent e33325c commit e32dc33
Showing 1 changed file with 69 additions and 92 deletions.
161 changes: 69 additions & 92 deletions aesara/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from aesara.configdefaults import config
from aesara.graph.basic import Variable, graph_inputs, io_toposort
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization
from aesara.graph.features import AlreadyThere, BadOptimization, Feature
from aesara.graph.fg import InconsistencyError
from aesara.graph.op import COp, HasInnerGraph, Op
from aesara.graph.utils import MethodNotDefined
Expand Down Expand Up @@ -433,7 +433,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
equivalence_tracker = _VariableEquivalenceTracker()
fgraph, updates = std_fgraph(input_specs, output_specs, accept_inplace)
fgraph.attach_feature(equivalence_tracker)
return fgraph, updates, equivalence_tracker
return fgraph, updates


class DataDestroyed:
Expand Down Expand Up @@ -1181,96 +1181,84 @@ def __ne__(self, other):
return not (self == other)


class _VariableEquivalenceTracker:
class _VariableEquivalenceTracker(Feature):
"""
A FunctionGraph Feature that keeps tabs on an FunctionGraph and
tries to detect problems.
"""

fgraph = None
"""WRITEME"""

equiv = None
"""WRITEME"""

active_nodes = None
"""WRITEME"""

inactive_nodes = None
"""WRITEME"""

all_variables_ever = None
"""WRITEME"""

reasons = None
"""WRITEME"""

replaced_by = None
"""WRITEME"""
def on_attach(self, fgraph):

event_list = None
"""WRITEME"""
if hasattr(fgraph, "_eq_tracker_equiv"):
raise AlreadyThere()

def __init__(self):
self.fgraph = None
fgraph._eq_tracker_equiv = {}
fgraph._eq_tracker_active_nodes = set()
fgraph._eq_tracker_inactive_nodes = set()
fgraph._eq_tracker_fgraph = fgraph
fgraph._eq_tracker_all_variables_ever = []
fgraph._eq_tracker_reasons = {}
fgraph._eq_tracker_replaced_by = {}
fgraph._eq_tracker_event_list = []

def on_attach(self, fgraph):
assert self.fgraph is None
self.equiv = {}
self.active_nodes = set()
self.inactive_nodes = set()
self.fgraph = fgraph
self.all_variables_ever = []
self.reasons = {}
self.replaced_by = {}
self.event_list = []
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
self.on_import(fgraph, node, "var_equiv_on_attach")

def on_detach(self, fgraph):
assert fgraph is self.fgraph
self.fgraph = None
del fgraph._eq_tracker_equiv
del fgraph._eq_tracker_active_nodes
del fgraph._eq_tracker_inactive_nodes
del fgraph._eq_tracker_fgraph
del fgraph._eq_tracker_all_variables_ever
del fgraph._eq_tracker_reasons
del fgraph._eq_tracker_replaced_by
del fgraph._eq_tracker_event_list

def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason)))
assert node in self.active_nodes
assert node not in self.inactive_nodes
self.active_nodes.remove(node)
self.inactive_nodes.add(node)
fgraph._eq_tracker_event_list.append(
_FunctionGraphEvent("prune", node, reason=str(reason))
)
assert node in fgraph._eq_tracker_active_nodes
assert node not in fgraph._eq_tracker_inactive_nodes
fgraph._eq_tracker_active_nodes.remove(node)
fgraph._eq_tracker_inactive_nodes.add(node)

def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason)))
fgraph._eq_tracker_event_list.append(
_FunctionGraphEvent("import", node, reason=str(reason))
)

assert node not in self.active_nodes
self.active_nodes.add(node)
assert node not in fgraph._eq_tracker_active_nodes
fgraph._eq_tracker_active_nodes.add(node)

if node in self.inactive_nodes:
self.inactive_nodes.remove(node)
if node in fgraph._eq_tracker_inactive_nodes:
fgraph._eq_tracker_inactive_nodes.remove(node)
for r in node.outputs:
assert r in self.equiv
assert r in fgraph._eq_tracker_equiv
else:
for r in node.outputs:
assert r not in self.equiv
self.equiv[r] = {r}
self.all_variables_ever.append(r)
self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])
assert r not in fgraph._eq_tracker_equiv
fgraph._eq_tracker_equiv[r] = {r}
fgraph._eq_tracker_all_variables_ever.append(r)
fgraph._eq_tracker_reasons.setdefault(r, [])
fgraph._eq_tracker_replaced_by.setdefault(r, [])
for r in node.inputs:
self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])
fgraph._eq_tracker_reasons.setdefault(r, [])
fgraph._eq_tracker_replaced_by.setdefault(r, [])

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
reason = str(reason)
self.event_list.append(
fgraph._eq_tracker_event_list.append(
_FunctionGraphEvent("change", node, reason=reason, idx=i)
)

self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, [])
fgraph._eq_tracker_reasons.setdefault(new_r, [])
fgraph._eq_tracker_replaced_by.setdefault(new_r, [])

append_reason = True
for tup in self.reasons[new_r]:
for tup in fgraph._eq_tracker_reasons[new_r]:
if tup[0] == reason and tup[1] is r:
append_reason = False

Expand All @@ -1279,7 +1267,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
# optimizations will change the graph
done = dict()
used_ids = dict()
self.reasons[new_r].append(
fgraph._eq_tracker_reasons[new_r].append(
(
reason,
r,
Expand All @@ -1303,19 +1291,19 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
).getvalue(),
)
)
self.replaced_by[r].append((reason, new_r))
fgraph._eq_tracker_replaced_by[r].append((reason, new_r))

if r in self.equiv:
r_set = self.equiv[r]
if r in fgraph._eq_tracker_equiv:
r_set = fgraph._eq_tracker_equiv[r]
else:
r_set = self.equiv.setdefault(r, {r})
self.all_variables_ever.append(r)
r_set = fgraph._eq_tracker_equiv.setdefault(r, {r})
fgraph._eq_tracker_all_variables_ever.append(r)

if new_r in self.equiv:
new_r_set = self.equiv[new_r]
if new_r in fgraph._eq_tracker_equiv:
new_r_set = fgraph._eq_tracker_equiv[new_r]
else:
new_r_set = self.equiv.setdefault(new_r, {new_r})
self.all_variables_ever.append(new_r)
new_r_set = fgraph._eq_tracker_equiv.setdefault(new_r, {new_r})
fgraph._eq_tracker_all_variables_ever.append(new_r)

assert new_r in new_r_set
assert r in r_set
Expand All @@ -1324,17 +1312,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
# transfer all the elements of the old one to the new one
r_set.update(new_r_set)
for like_new_r in new_r_set:
self.equiv[like_new_r] = r_set
fgraph._eq_tracker_equiv[like_new_r] = r_set
assert like_new_r in r_set

assert self.equiv[r] is r_set
assert self.equiv[new_r] is r_set

def printstuff(self):
for key in self.equiv:
print(key)
for e in self.equiv[key]:
print(" ", e)
assert fgraph._eq_tracker_equiv[r] is r_set
assert fgraph._eq_tracker_equiv[new_r] is r_set


# List of default version of make thunk.
Expand Down Expand Up @@ -1390,9 +1372,7 @@ def make_all(
# Compute a topological ordering that IGNORES the destroy_map
# of destructive Ops. This will be OK, because every thunk is
# evaluated on a copy of its input.
fgraph_equiv = fgraph.equivalence_tracker
order_outputs = copy.copy(fgraph_equiv.all_variables_ever)
del fgraph_equiv
order_outputs = copy.copy(fgraph._eq_tracker_all_variables_ever)
order_outputs.reverse()
order = io_toposort(fgraph.inputs, order_outputs)

Expand Down Expand Up @@ -1625,7 +1605,7 @@ def f():
# insert a given apply node. If that is not True,
# we would need to loop over all node outputs,
# But this make the output uglier.
reason = fgraph.equivalence_tracker.reasons[node.outputs[0]]
reason = fgraph._eq_tracker_reasons[node.outputs[0]]
if not reason:
raise
opt = str(reason[0][0])
Expand Down Expand Up @@ -1738,7 +1718,7 @@ def f():
# insert a given apply node. If that is not True,
# we would need to loop over all node outputs,
# But this make the output uglier.
reason = fgraph.equivalence_tracker.reasons[node.outputs[0]]
reason = fgraph._eq_tracker_reasons[node.outputs[0]]
if not reason:
raise
opt = str(reason[0][0])
Expand Down Expand Up @@ -1865,9 +1845,7 @@ def thunk():
# But it is very slow and it is not sure it will help.
gc.collect()

_find_bad_optimizations(
order, fgraph.equivalence_tracker.reasons, r_vals
)
_find_bad_optimizations(order, fgraph._eq_tracker_reasons, r_vals)

#####
# Postcondition: the input and output variables are
Expand Down Expand Up @@ -2058,10 +2036,9 @@ def __init__(

# make the fgraph
for i in range(mode.stability_patience):
fgraph, additional_outputs, equivalence_tracker = _optcheck_fgraph(
fgraph, additional_outputs = _optcheck_fgraph(
inputs, outputs, accept_inplace
)
fgraph.equivalence_tracker = equivalence_tracker

with config.change_flags(compute_test_value=config.compute_test_value_opt):
optimizer(fgraph)
Expand All @@ -2073,8 +2050,8 @@ def __init__(
if i == 0:
fgraph0 = fgraph
else:
li = fgraph.equivalence_tracker.event_list
l0 = fgraph0.equivalence_tracker.event_list
li = fgraph._eq_tracker_event_list
l0 = fgraph0._eq_tracker_event_list
if li != l0:
infolog = StringIO()
print("Optimization process is unstable...", file=infolog)
Expand Down

0 comments on commit e32dc33

Please sign in to comment.