Skip to content

Commit

Permalink
Attach static methods in Validator and ReplaceValidate
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Feb 21, 2022
1 parent 05e0287 commit 17831c6
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions aesara/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import time
import types
from collections import OrderedDict
from functools import partial
from io import StringIO

import numpy as np
Expand Down Expand Up @@ -441,12 +440,12 @@ def on_attach(self, fgraph):
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.validate = partial(self.validate_, fgraph)
fgraph.consistent = partial(self.consistent_, fgraph)
fgraph.validate = types.MethodType(self.validate_, fgraph)
fgraph.consistent = types.MethodType(self.consistent_, fgraph)

def unpickle(self, fgraph):
fgraph.validate = partial(self.validate_, fgraph)
fgraph.consistent = partial(self.consistent_, fgraph)
fgraph.validate = types.MethodType(self.validate_, fgraph)
fgraph.consistent = types.MethodType(self.consistent_, fgraph)

def on_detach(self, fgraph):
"""
Expand All @@ -456,7 +455,8 @@ def on_detach(self, fgraph):
del fgraph.validate
del fgraph.consistent

def validate_(self, fgraph):
@staticmethod
def validate_(fgraph):
"""
If the caller is replace_all_validate, just raise the
exception. replace_all_validate will print out the
Expand Down Expand Up @@ -488,7 +488,8 @@ def validate_(self, fgraph):
fgraph.profile.validate_time += t1 - t0
return ret

def consistent_(self, fgraph):
@staticmethod
def consistent_(fgraph):
try:
fgraph.validate()
return True
Expand Down Expand Up @@ -544,12 +545,14 @@ def on_detach(self, fgraph):
del fgraph.replace_all_validate
del fgraph.replace_all_validate_remove

def replace_validate(self, fgraph, r, new_r, reason=None, **kwargs):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason, **kwargs)
@staticmethod
def replace_validate(fgraph, r, new_r, reason=None, **kwargs):
ReplaceValidate.replace_all_validate(
fgraph, [(r, new_r)], reason=reason, **kwargs
)

def replace_all_validate(
self, fgraph, replacements, reason=None, verbose=None, **kwargs
):
@staticmethod
def replace_all_validate(fgraph, replacements, reason=None, verbose=None, **kwargs):
chk = fgraph.checkpoint()

if verbose is None:
Expand Down Expand Up @@ -602,8 +605,9 @@ def replace_all_validate(
# The return is needed by replace_all_validate_remove
return chk

@staticmethod
def replace_all_validate_remove(
self, fgraph, replacements, remove, reason=None, warn=True, **kwargs
fgraph, replacements, remove, reason=None, warn=True, **kwargs
):
"""
As replace_all_validate, revert the replacement if the ops
Expand Down

0 comments on commit 17831c6

Please sign in to comment.