-
-
Notifications
You must be signed in to change notification settings - Fork 113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
A simple solution to work on the copy of the graph #15
Conversation
Tests seem to be broken, do they pass localy? Solution looks nice, are there any corner cases? What's happening when we use nested models? |
Also what if |
@ferrine, some problems with using the the right graph with
|
Looks nice - once tests are fixed, it might be worth wrapping this in its own decorator or method. |
pymc4/model/base.py
Outdated
with ed.interception(interceptor): | ||
temp_graph = tf.Graph() | ||
tf.contrib.graph_editor.copy(self.graph, temp_graph) | ||
with temp_graph.as_default(), ed.interception(interceptor): | ||
self._f(self._cfg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not related to this PR, but this line should probably be self._f(self.cfg)
, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO it shouldnt make a difference. self.cfg
returns self._cfg
. Although I am not sure why we are saving it as _cfg
The CI is failing because they removed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry for late review!
@@ -38,6 +38,7 @@ def __init__(self, name=None, graph=None, session=None, **config): | |||
session = tf.Session(graph=graph) | |||
self.session = session | |||
self.observe(**config) | |||
self.temp_graph = tf.Graph() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would use this pattern for creating and destroying the temporary graph. I'm using a pymc3 model here because I think the pattern is similar, but let me know if it isn't! Note that you will only need to implement the equivalent of the temp_model
method below.
import pymc3 as pm
def check_in_model_context():
try:
m = pm.modelcontext(None)
return f'In model context! {len(m.vars)} vars!'
except TypeError:
return 'Not in model context!'
class Model(object):
def __init__(self):
self._temp_model = None
@contextmanager
def temp_model(self):
self._temp_model = pm.Model()
try:
with self._temp_model:
yield
finally:
self._temp_model = None
def calculation(self):
print(check_in_model_context())
with self.temp_model():
pm.Normal('x', 0, 1)
print(check_in_model_context())
print(check_in_model_context())
with self.temp_model():
print(check_in_model_context())
print(check_in_model_context())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
decorator for inclass function is causing problems.. I will look into
@@ -12,7 +12,7 @@ | |||
VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv') | |||
|
|||
|
|||
class Interceptor(object): | |||
class Interceptor(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good practice to inherit from object, even if you don't need to!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? I thought in py3 we don't need that anymore.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I am facing a linting error when I import from object
Explicit is better than implicit
вс, 29 июл. 2018 г., 15:03 Thomas Wiecki <[email protected]>:
… ***@***.**** commented on this pull request.
------------------------------
In pymc4/util/interceptors.py
<#15 (comment)>:
> @@ -12,7 +12,7 @@
VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv')
-class Interceptor(object):
+class Interceptor():
Why? I thought in py3 we don't need that anymore.
—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
<#15 (comment)>, or mute
the thread
<https://github.com/notifications/unsubscribe-auth/ALKb7jmrRGmUrnjPtKeGaI4wIWyTBUlkks5uLaSNgaJpZM4Vffeo>
.
|
There is nothing explicit about this, it's a python 2 hack to get super working (or something like that). |
Sort of a jokey discussion, but some nuggets of wisdom in there, notably that there aren't strong feelings either way: I'm fine leaving it out. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The proposed approach has important corner cases not supported
@@ -54,6 +55,7 @@ def _init_variables(self): | |||
info_collector = interceptors.CollectVariablesInfo() | |||
with self.graph.as_default(), ed.interception(info_collector): | |||
self._f(self.cfg) | |||
tf.contrib.graph_editor.copy(self.graph, self.temp_graph) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So 1) You create a temp graph for inner manipulations
with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)): | ||
tf.contrib.graph_editor.copy(self.graph, self.temp_graph) | ||
# pylint: disable=not-context-manager | ||
with self.temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You reuse tmp graph created on step 1) suppose cfg
has any tensor from original graph and you call your model with temp_graph.as_default()
. That creates all variables for temp_graph
. This looks fine unless the case I'm talking about. If you attempt to do an operation on two variables from different graphs, you will get an exception
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not know a solution for this case yet. I've seen some kind of working and relevant example in tensorflow codebase
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/function.py#L347
Closing as the main prototype is from the london branch. |
No description provided.