From 44d600acdc8c894c3a8d02b9938d09d000bf9b95 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Wed, 7 Oct 2020 16:01:34 +0200 Subject: [PATCH 1/7] use new str() representations in GraphViz output --- pymc3/model_graph.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/pymc3/model_graph.py b/pymc3/model_graph.py index 93fc3a8aab3..1be77d33227 100644 --- a/pymc3/model_graph.py +++ b/pymc3/model_graph.py @@ -133,22 +133,21 @@ def _make_node(self, var_name, graph): if isinstance(v, SharedVariable): attrs['style'] = 'rounded, filled' - # Get name for node + # determine the shape for this node (default (Distribution) is ellipse) if v in self.model.potentials: - distribution = 'Potential' attrs['shape'] = 'octagon' - elif hasattr(v, 'distribution'): - distribution = v.distribution.__class__.__name__ - elif isinstance(v, SharedVariable): - distribution = 'Data' + elif isinstance(v, SharedVariable) or not hasattr(v, 'distribution'): + # shared variables and Deterministic represented by a box attrs['shape'] = 'box' + + if v in self.model.potentials: + label = f'{var_name}\n~\nPotential' + elif isinstance(v, SharedVariable): + label = f'{var_name}\n~\Data' else: - distribution = 'Deterministic' - attrs['shape'] = 'box' + label = str(v).replace(' ~ ', '\n~\n') - graph.node(var_name.replace(':', '&'), - f'{var_name}\n~\n{distribution}', - **attrs) + graph.node(var_name.replace(':', '&'), label, **attrs) def get_plates(self): """ Rough but surprisingly accurate plate detection. From 006e930c177c00a7a97ab462d4ae6c6fa8fd6d07 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Wed, 7 Oct 2020 16:08:07 +0200 Subject: [PATCH 2/7] typo --- pymc3/model_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/model_graph.py b/pymc3/model_graph.py index 1be77d33227..a5793790e70 100644 --- a/pymc3/model_graph.py +++ b/pymc3/model_graph.py @@ -143,7 +143,7 @@ def _make_node(self, var_name, graph): if v in self.model.potentials: label = f'{var_name}\n~\nPotential' elif isinstance(v, SharedVariable): - label = f'{var_name}\n~\Data' + label = f'{var_name}\n~\nData' else: label = str(v).replace(' ~ ', '\n~\n') From c52547fe649ccb8025bb27aa7493cf659c3b055b Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Wed, 7 Oct 2020 17:19:56 +0200 Subject: [PATCH 3/7] updating tests --- pymc3/tests/test_data_container.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index b4c3e0009a4..5bc79360336 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -183,9 +183,9 @@ def test_model_to_graphviz_for_model_with_data_container(self): text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]' assert text in g.source # Didn't break ordinary variables? - text = 'beta [label="beta\n~\nNormal"]' + text = 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]' assert text in g.source - text = 'obs [label="obs\n~\nNormal" style=filled]' + text = 'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma=0.1)" style=filled]' assert text in g.source def test_explicit_coords(self): From 401686e96af4e53fc9006da34acd9dd6d9bb59a8 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Oct 2020 09:28:36 +0200 Subject: [PATCH 4/7] add debug print info for Travis-CI --- pymc3/tests/test_data_container.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index 5bc79360336..8098a271b4d 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -179,6 +179,7 @@ def test_model_to_graphviz_for_model_with_data_container(self): g = pm.model_to_graphviz(model) + print(g.source) # Data node rendered correctly? text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]' assert text in g.source From 7b60ab01316fae55d9dd23584925ccc031a0393f Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Oct 2020 10:04:46 +0200 Subject: [PATCH 5/7] fix failing test due to floating point numerical accuracy --- pymc3/tests/test_data_container.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index 8098a271b4d..8d6fa4efe4b 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -174,12 +174,11 @@ def test_model_to_graphviz_for_model_with_data_container(self): x = pm.Data("x", [1.0, 2.0, 3.0]) y = pm.Data("y", [1.0, 2.0, 3.0]) beta = pm.Normal("beta", 0, 10.0) - pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y) + pm.Normal("obs", beta * x, 0.1, observed=y) pm.sample(1000, init=None, tune=1000, chains=1) g = pm.model_to_graphviz(model) - print(g.source) # Data node rendered correctly? text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]' assert text in g.source From 44b6b14e0e7397a3a7eac22d224677a9d9b26ab2 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Oct 2020 10:29:38 +0200 Subject: [PATCH 6/7] another attempt at fixing float accuracy test fail --- pymc3/tests/test_data_container.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index 8d6fa4efe4b..668febc3e95 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -174,18 +174,20 @@ def test_model_to_graphviz_for_model_with_data_container(self): x = pm.Data("x", [1.0, 2.0, 3.0]) y = pm.Data("y", [1.0, 2.0, 3.0]) beta = pm.Normal("beta", 0, 10.0) - pm.Normal("obs", beta * x, 0.1, observed=y) + obs_sigma = np.sqrt(1e-2) + pm.Normal("obs", beta * x, obs_sigma, observed=y) pm.sample(1000, init=None, tune=1000, chains=1) g = pm.model_to_graphviz(model) + print(g.source) # Data node rendered correctly? text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]' assert text in g.source # Didn't break ordinary variables? text = 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]' assert text in g.source - text = 'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma=0.1)" style=filled]' + text = f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]' assert text in g.source def test_explicit_coords(self): From cccedbede7628f8729929538c0096f37cfa317e6 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Oct 2020 11:06:17 +0200 Subject: [PATCH 7/7] ensure string format as proper floatX --- pymc3/tests/test_data_container.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index 668febc3e95..c9ee9e3dd07 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -13,6 +13,7 @@ # limitations under the License. import pymc3 as pm +from ..theanof import floatX from .helpers import SeededTest import numpy as np import pandas as pd @@ -174,13 +175,12 @@ def test_model_to_graphviz_for_model_with_data_container(self): x = pm.Data("x", [1.0, 2.0, 3.0]) y = pm.Data("y", [1.0, 2.0, 3.0]) beta = pm.Normal("beta", 0, 10.0) - obs_sigma = np.sqrt(1e-2) + obs_sigma = floatX(np.sqrt(1e-2)) pm.Normal("obs", beta * x, obs_sigma, observed=y) pm.sample(1000, init=None, tune=1000, chains=1) g = pm.model_to_graphviz(model) - print(g.source) # Data node rendered correctly? text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]' assert text in g.source