diff --git a/pymc3/model_graph.py b/pymc3/model_graph.py index adede230ed9..02122d7948d 100644 --- a/pymc3/model_graph.py +++ b/pymc3/model_graph.py @@ -134,20 +134,22 @@ def _make_node(self, var_name, graph): if isinstance(v, SharedVariable): attrs["style"] = "rounded, filled" - # Get name for node + # determine the shape for this node (default (Distribution) is ellipse) if v in self.model.potentials: - distribution = "Potential" - attrs["shape"] = "octagon" - elif hasattr(v, "distribution"): - distribution = v.distribution.__class__.__name__ + attrs['shape'] = 'octagon' + elif isinstance(v, SharedVariable) or not hasattr(v, 'distribution'): + # shared variables and Deterministic represented by a box + attrs['shape'] = 'box' + + if v in self.model.potentials: + label = f'{var_name}\n~\nPotential' elif isinstance(v, SharedVariable): - distribution = "Data" - attrs["shape"] = "box" + label = f'{var_name}\n~\nData' else: - distribution = "Deterministic" - attrs["shape"] = "box" + label = str(v).replace(' ~ ', '\n~\n') + + graph.node(var_name.replace(':', '&'), label, **attrs) - graph.node(var_name.replace(":", "&"), f"{var_name}\n~\n{distribution}", **attrs) def get_plates(self): """Rough but surprisingly accurate plate detection. diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index b4c3e0009a4..c9ee9e3dd07 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -13,6 +13,7 @@ # limitations under the License. import pymc3 as pm +from ..theanof import floatX from .helpers import SeededTest import numpy as np import pandas as pd @@ -174,7 +175,8 @@ def test_model_to_graphviz_for_model_with_data_container(self): x = pm.Data("x", [1.0, 2.0, 3.0]) y = pm.Data("y", [1.0, 2.0, 3.0]) beta = pm.Normal("beta", 0, 10.0) - pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y) + obs_sigma = floatX(np.sqrt(1e-2)) + pm.Normal("obs", beta * x, obs_sigma, observed=y) pm.sample(1000, init=None, tune=1000, chains=1) g = pm.model_to_graphviz(model) @@ -183,9 +185,9 @@ def test_model_to_graphviz_for_model_with_data_container(self): text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]' assert text in g.source # Didn't break ordinary variables? - text = 'beta [label="beta\n~\nNormal"]' + text = 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]' assert text in g.source - text = 'obs [label="obs\n~\nNormal" style=filled]' + text = f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]' assert text in g.source def test_explicit_coords(self):