Skip to content

Commit

Permalink
Merge pull request #20 from mrMakaronka/tf-2.5-support
Browse files Browse the repository at this point in the history
add StackSummary reducer only for tensorflow < 2.5
  • Loading branch information
AlvinMax authored May 19, 2021
2 parents 250f5a8 + 1e347dc commit 4650fa3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'pybase64>=1.0.0,<=1.0.2',
'cython==0.29.5',
'pympler==0.9',
'packaging==20.9',
],
tests_require=[
'numpy',
Expand Down
19 changes: 14 additions & 5 deletions src/main/python/ipystate/impl/dispatch/tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import sys

from ipystate.impl.dispatch.dispatcher import Dispatcher
import tensorflow as tf
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils
import os
import json
import pybase64
from packaging import version


class TensorflowDispatcher(Dispatcher):
Expand Down Expand Up @@ -113,10 +116,16 @@ def register(self, dispatch):
dispatch[tf.Operation] = self._reduce_tf_op
dispatch[tf.keras.Model] = self._reduce_tf_model
dispatch[tf.keras.Sequential] = self._reduce_tf_model
if int(tf.__version__.split('.')[0]) <= 1:
if version.parse(tf.__version__) < version.parse('2.0.0'):
pass
else:
from tensorflow.python.ops.variable_scope import _VariableScopeStore
dispatch[_VariableScopeStore] = self._reduce_without_args(_VariableScopeStore)
from tensorflow.python._tf_stack import StackSummary
dispatch[StackSummary] = self._reduce_without_args(StackSummary)
try:
from tensorflow.python.ops.variable_scope import _VariableScopeStore
dispatch[_VariableScopeStore] = self._reduce_without_args(_VariableScopeStore)
if version.parse(tf.__version__) < version.parse('2.5.0'):
from tensorflow.python._tf_stack import StackSummary
dispatch[StackSummary] = self._reduce_without_args(StackSummary)
except ModuleNotFoundError:
print(
"Warning: some TensorFlow objects may not be serialized. Try to use TensorFlow 1.5 or 2.3 for full compatibility.",
file=sys.stderr)

0 comments on commit 4650fa3

Please sign in to comment.