From 6acc09f573a23d4ff01b103080c2e2764158e90c Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Thu, 28 Jul 2022 19:33:10 -0500 Subject: [PATCH 01/13] add dynamic bayesian network class --- causalnex/network/__init__.py | 1 + causalnex/network/network.py | 81 +++ tests/test_dynamicbayesiannetwork.py | 925 +++++++++++++++++++++++++++ 3 files changed, 1007 insertions(+) create mode 100644 tests/test_dynamicbayesiannetwork.py diff --git a/causalnex/network/__init__.py b/causalnex/network/__init__.py index 8b94a7a..0d9de37 100644 --- a/causalnex/network/__init__.py +++ b/causalnex/network/__init__.py @@ -33,3 +33,4 @@ __all__ = ["BayesianNetwork"] from .network import BayesianNetwork +from .network import DynamicBayesianNetwork diff --git a/causalnex/network/network.py b/causalnex/network/network.py index 908e55a..3a6a230 100644 --- a/causalnex/network/network.py +++ b/causalnex/network/network.py @@ -736,3 +736,84 @@ def _predict_probability_from_incomplete_data( probability = probability[cols] probability.columns = cols return probability + + +class DynamicBayesianNetwork(BayesianNetwork): + """ + Base class for Dynamic Bayesian Network (DBN), a probabilistic weighted DAG where nodes represent variables, + edges represent the causal relationships between variables. + + ``DynamicBayesianNetwork`` stores nodes with their possible states, edges and + conditional probability distributions (CPDs) of each node. + + ``DynamicBayesianNetwork`` is built on top of the ``StructureModel``, which is an extension of ``networkx.DiGraph`` + (see :func:`causalnex.structure.structuremodel.StructureModel`). + + In order to define the ``DynamicBayesianNetwork``, users should provide a relevant ``StructureModel``. + Once ``DynamicBayesianNetwork`` is initialised, no changes to the ``StructureModel`` can be made + and CPDs can be learned from the data. + + The learned CPDs can be then used for likelihood estimation and predictions. 
+ + Example: + :: + >>> # Create a Dynamic Bayesian Network with a manually defined DAG. + >>> from causalnex.structure import StructureModel + >>> from causalnex.network import DynamicBayesianNetwork + >>> + >>> sm = StructureModel() + >>> sm.add_edges_from([ + >>> ('rush_hour', 'traffic'), + >>> ('weather', 'traffic') + >>> ]) + >>> dbn = DynamicBayesianNetwork(sm) + >>> # A created ``DynamicBayesianNetwork`` stores nodes and edges defined by the ``StructureModel`` + >>> dbn.nodes + ['rush_hour', 'traffic', 'weather'] + >>> + >>> dbn.edges + [('rush_hour', 'traffic'), ('weather', 'traffic')] + >>> # A ``DynamicBayesianNetwork`` doesn't store any CPDs yet + >>> dbn.cpds + {} + >>> + >>> # Learn the nodes' states from the data + >>> import pandas as pd + >>> data = pd.DataFrame({ + >>> 'rush_hour': [True, False, False, False, True, False, True], + >>> 'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'], + >>> 'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy'] + >>> }) + >>> dbn = dbn.fit_node_states(data) + >>> dbn.node_states + {'rush_hour': {False, True}, 'weather': {'Bad', 'Good', 'Terrible'}, 'traffic': {'heavy', 'light'}} + >>> # Learn the CPDs from the data + >>> dbn = dbn.fit_cpds(data) + >>> # Use the learned CPDs to make predictions on the unseen data + >>> test_data = pd.DataFrame({ + >>> 'rush_hour': [False, False, True, True], + >>> 'weather': ['Good', 'Bad', 'Good', 'Bad'] + >>> }) + >>> dbn.predict(test_data, "traffic").to_dict() + {'traffic_prediction': {0: 'light', 1: 'heavy', 2: 'heavy', 3: 'heavy'}} + >>> + >>> dbn.predict_probability(test_data, "traffic").to_dict() + {'traffic_light': {0: 0.75, 1: 0.25, 2: 0.3333333333333333, 3: 0.3333333333333333}, + 'traffic_heavy': {0: 0.25, 1: 0.75, 2: 0.6666666666666666, 3: 0.6666666666666666}} + """ + + def __init__(self, structure: StructureModel): + """ + Create a 
``DynamicBayesianNetwork`` with a DAG defined by ``StructureModel``. + + Args: + structure: a graph representing a causal relationship between variables. + In the structure + - cycles are not allowed; + - multiple (parallel) edges are not allowed; + - isolated nodes and multiple components are not allowed. + + Raises: + ValueError: If the structure is not a connected DAG. + """ + super().__init__(structure) \ No newline at end of file diff --git a/tests/test_dynamicbayesiannetwork.py b/tests/test_dynamicbayesiannetwork.py new file mode 100644 index 0000000..7e0ed1e --- /dev/null +++ b/tests/test_dynamicbayesiannetwork.py @@ -0,0 +1,925 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. 
+# +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import pandas as pd +import pytest + +from causalnex.evaluation import classification_report, roc_auc +from causalnex.inference import InferenceEngine +from causalnex.network import DynamicBayesianNetwork +from causalnex.structure import StructureModel +from causalnex.structure.notears import from_pandas + +from .estimator.test_em import naive_bayes_plus_parents + + +class TestFitNodeStates: + """Test behaviour of fit node states method""" + + @pytest.mark.parametrize( + "weighted_edges, data", + [ + ([("a", "b", 1)], pd.DataFrame([[1, 1]], columns=["a", "b"])), + ( + [("a", "b", 1)], + pd.DataFrame([[1, 1, 1, 1]], columns=["a", "b", "c", "d"]), + ), + # c and d are isolated nodes in the data + ], + ) + def test_all_nodes_included(self, weighted_edges, data): + """No errors if all the nodes can be found in the columns of training data""" + cg = StructureModel() + cg.add_weighted_edges_from(weighted_edges) + bn = DynamicBayesianNetwork(cg).fit_node_states(data) + assert all(node in data.columns for node in bn.node_states.keys()) + + def test_all_states_included(self): + """All states in a node should be included""" + cg = StructureModel() + cg.add_weighted_edges_from([("a", "b", 1)]) + bn = DynamicBayesianNetwork(cg).fit_node_states( + pd.DataFrame([[i, i] for i in range(10)], columns=["a", "b"]) + ) + assert all(v in bn.node_states["a"] for v in range(10)) + + def test_fit_with_null_states_raises_error(self): + """An error should be raised if fit is called with null data""" + cg = StructureModel() + cg.add_weighted_edges_from([("a", "b", 1)]) + with pytest.raises(ValueError, match="node '.*' contains None state"): + DynamicBayesianNetwork(cg).fit_node_states( + pd.DataFrame([[None, 1]], columns=["a", "b"]) + ) + + def test_fit_with_missing_feature_in_data(self): + """An error should be raised if fit is called with missing 
feature in data""" + cg = StructureModel() + + cg.add_weighted_edges_from([("a", "e", 1)]) + with pytest.raises( + KeyError, + match="The data does not cover all the features found in the Bayesian Network. " + "Please check the following features: {'e'}", + ): + DynamicBayesianNetwork(cg).fit_node_states( + pd.DataFrame([[1, 1, 1, 1]], columns=["a", "b", "c", "d"]) + ) + + +class TestFitCPDSErrors: + """Test errors for fit CPDs method""" + + def test_invalid_method(self, bn, train_data_discrete): + """a value error should be raised if an invalid method is provided""" + + with pytest.raises(ValueError, match=r"unrecognised method.*"): + bn.fit_cpds(train_data_discrete, method="INVALID") + + def test_invalid_prior(self, bn, train_data_discrete): + """a value error should be raised if an invalid prior is provided""" + + with pytest.raises(ValueError, match=r"unrecognised bayes_prior.*"): + bn.fit_cpds( + train_data_discrete, method="BayesianEstimator", bayes_prior="INVALID" + ) + + +class TestFitCPDsMaximumLikelihoodEstimator: + """Test behaviour of fit_cpds using MLE""" + + def test_cause_only_node(self, bn, train_data_discrete, train_data_discrete_cpds): + """Test that probabilities are fit correctly to nodes which are not caused by other nodes""" + + bn.fit_cpds(train_data_discrete) + cpds = bn.cpds + + assert ( + np.mean( + np.abs( + cpds["d"].values.reshape(2) + - train_data_discrete_cpds["d"].reshape(2) + ) + ) + < 1e-7 + ) + assert ( + np.mean( + np.abs( + cpds["e"].values.reshape(2) + - train_data_discrete_cpds["e"].reshape(2) + ) + ) + < 1e-7 + ) + + def test_dependent_node(self, bn, train_data_discrete, train_data_discrete_cpds): + """Test that probabilities are fit correctly to nodes that are caused by other nodes""" + + bn.fit_cpds(train_data_discrete) + cpds = bn.cpds + + assert ( + np.mean( + np.abs( + cpds["a"].values.reshape(24) + - train_data_discrete_cpds["a"].reshape(24) + ) + ) + < 1e-7 + ) + assert ( + np.mean( + np.abs( + 
cpds["b"].values.reshape(12) + - train_data_discrete_cpds["b"].reshape(12) + ) + ) + < 1e-7 + ) + assert ( + np.mean( + np.abs( + cpds["c"].values.reshape(60) + - train_data_discrete_cpds["c"].reshape(60) + ) + ) + < 1e-7 + ) + + def test_fit_missing_states(self): + """test issues/15: should be possible to fit with missing states""" + + sm = StructureModel([("a", "b"), ("c", "b")]) + bn = DynamicBayesianNetwork(sm) + + train = pd.DataFrame( + data=[[0, 0, 1], [1, 0, 1], [1, 1, 1]], columns=["a", "b", "c"] + ) + test = pd.DataFrame( + data=[[0, 0, 1], [1, 0, 1], [1, 1, 2]], columns=["a", "b", "c"] + ) + data = pd.concat([train, test]) + + bn.fit_node_states(data) + bn.fit_cpds(train) + + assert bn.cpds["c"].loc[1][0] == 1 + assert bn.cpds["c"].loc[2][0] == 0 + + +class TestFitBayesianEstimator: + """Test behaviour of fit_cpds using BE""" + + def test_cause_only_node_bdeu( + self, bn, train_data_discrete, train_data_discrete_cpds + ): + """Test that probabilities are fit correctly to nodes which are not caused by other nodes""" + + bn.fit_cpds( + train_data_discrete, + method="BayesianEstimator", + bayes_prior="BDeu", + equivalent_sample_size=5, + ) + cpds = bn.cpds + + assert ( + np.mean( + np.abs( + cpds["d"].values.reshape(2) + - train_data_discrete_cpds["d"].reshape(2) + ) + ) + < 0.02 + ) + assert ( + np.mean( + np.abs( + cpds["e"].values.reshape(2) + - train_data_discrete_cpds["e"].reshape(2) + ) + ) + < 0.02 + ) + + def test_cause_only_node_k2( + self, bn, train_data_discrete, train_data_discrete_cpds + ): + """Test that probabilities are fit correctly to nodes which are not caused by other nodes""" + + bn.fit_cpds(train_data_discrete, method="BayesianEstimator", bayes_prior="K2") + cpds = bn.cpds + + assert ( + np.mean( + np.abs( + cpds["d"].values.reshape(2) + - train_data_discrete_cpds["d"].reshape(2) + ) + ) + < 0.02 + ) + assert ( + np.mean( + np.abs( + cpds["e"].values.reshape(2) + - train_data_discrete_cpds["e"].reshape(2) + ) + ) + < 0.02 + ) + + def 
test_dependent_node_bdeu( + self, bn, train_data_discrete, train_data_discrete_cpds + ): + """Test that probabilities are fit correctly to nodes that are caused by other nodes""" + + bn.fit_cpds( + train_data_discrete, + method="BayesianEstimator", + bayes_prior="BDeu", + equivalent_sample_size=1, + ) + cpds = bn.cpds + + assert ( + np.mean( + np.abs( + cpds["a"].values.reshape(24) + - train_data_discrete_cpds["a"].reshape(24) + ) + ) + < 0.02 + ) + assert ( + np.mean( + np.abs( + cpds["b"].values.reshape(12) + - train_data_discrete_cpds["b"].reshape(12) + ) + ) + < 0.02 + ) + assert ( + np.mean( + np.abs( + cpds["c"].values.reshape(60) + - train_data_discrete_cpds["c"].reshape(60) + ) + ) + < 0.02 + ) + + def test_dependent_node_k2( + self, bn, train_data_discrete, train_data_discrete_cpds_k2 + ): + """Test that probabilities are fit correctly to nodes that are caused by other nodes""" + + bn.fit_cpds(train_data_discrete, method="BayesianEstimator", bayes_prior="K2") + cpds = bn.cpds + + assert ( + np.mean( + np.abs( + cpds["a"].values.reshape(24) + - train_data_discrete_cpds_k2["a"].reshape(24) + ) + ) + < 1e-7 + ) + assert ( + np.mean( + np.abs( + cpds["b"].values.reshape(12) + - train_data_discrete_cpds_k2["b"].reshape(12) + ) + ) + < 1e-7 + ) + assert ( + np.mean( + np.abs( + cpds["c"].values.reshape(60) + - train_data_discrete_cpds_k2["c"].reshape(60) + ) + ) + < 1e-7 + ) + + def test_fit_missing_states(self): + """test issues/15: should be possible to fit with missing states""" + + sm = StructureModel([("a", "b"), ("c", "b")]) + bn = DynamicBayesianNetwork(sm) + + train = pd.DataFrame( + data=[[0, 0, 1], [1, 0, 1], [1, 1, 1]], columns=["a", "b", "c"] + ) + test = pd.DataFrame( + data=[[0, 0, 1], [1, 0, 1], [1, 1, 2]], columns=["a", "b", "c"] + ) + data = pd.concat([train, test]) + + bn.fit_node_states(data) + bn.fit_cpds(train, method="BayesianEstimator", bayes_prior="K2") + + assert bn.cpds["c"].loc[1][0] == 0.8 + assert bn.cpds["c"].loc[2][0] == 0.2 + + 
+class TestPredictMaximumLikelihoodEstimator: + """Test behaviour of predict using MLE""" + + def test_predictions_are_based_on_probabilities( + self, bn, train_data_discrete, test_data_c_discrete + ): + """Predictions made using the model should be based on the probabilities that are in the model""" + + bn.fit_cpds(train_data_discrete) + predictions = bn.predict(test_data_c_discrete, "c") + assert np.all( + predictions.values.reshape(len(predictions.values)) + == test_data_c_discrete["c"].values + ) + + def test_prediction_node_suffixed_as_prediction( + self, bn, train_data_discrete, test_data_c_discrete + ): + """The column that contains the values of the predicted node should be named node_prediction""" + + bn.fit_cpds(train_data_discrete) + predictions = bn.predict(test_data_c_discrete, "c") + assert "c_prediction" in predictions.columns + + def test_only_predicted_column_returned( + self, bn, train_data_discrete, test_data_c_discrete + ): + """The returned df should not contain any of the input data columns""" + + bn.fit_cpds(train_data_discrete) + predictions = bn.predict(test_data_c_discrete, "c") + assert len(predictions.columns) == 1 + + def test_predictions_are_not_appended_to_input_df( + self, bn, train_data_discrete, test_data_c_discrete + ): + """The predictions should not be appended to the input df""" + + expected_cols = test_data_c_discrete.columns + bn.fit_cpds(train_data_discrete) + bn.predict(test_data_c_discrete, "c") + assert np.array_equal(test_data_c_discrete.columns, expected_cols) + + def test_missing_parent(self, bn, train_data_discrete, test_data_c_discrete): + """Predictions made when parents are missing should still be reasonably accurate""" + + bn.fit_cpds(train_data_discrete) + predictions = bn.predict(test_data_c_discrete[["a", "b", "c", "d"]], "c") + + n = len(test_data_c_discrete) + + accuracy = ( + 1 + - np.count_nonzero( + predictions.values.reshape(len(predictions.values)) + - test_data_c_discrete["c"].values + ) + / n + ) + + 
assert accuracy > 0.9 + + def test_missing_non_parent(self, bn, train_data_discrete, test_data_c_discrete): + """It should be possible to make predictions with non-parent nodes missing""" + + bn.fit_cpds(train_data_discrete) + predictions = bn.predict(test_data_c_discrete[["b", "c", "d", "e"]], "c") + assert np.all( + predictions.values.reshape(len(predictions.values)) + == test_data_c_discrete["c"].values + ) + + +class TestPredictBayesianEstimator: + """Test behaviour of predict using BE""" + + def test_predictions_are_based_on_probabilities_dbeu( + self, bn, train_data_discrete, test_data_c_discrete + ): + """Predictions made using the model should be based on the probabilities that are in the model""" + + bn.fit_cpds( + train_data_discrete, + method="BayesianEstimator", + bayes_prior="BDeu", + equivalent_sample_size=5, + ) + predictions = bn.predict(test_data_c_discrete, "c") + assert np.all( + predictions.values.reshape(len(predictions.values)) + == test_data_c_discrete["c"].values + ) + + def test_predictions_are_based_on_probabilities_k2( + self, bn, train_data_discrete, test_data_c_discrete + ): + """Predictions made using the model should be based on the probabilities that are in the model""" + + bn.fit_cpds( + train_data_discrete, + method="BayesianEstimator", + bayes_prior="K2", + equivalent_sample_size=5, + ) + predictions = bn.predict(test_data_c_discrete, "c") + assert np.all( + predictions.values.reshape(len(predictions.values)) + == test_data_c_discrete["c"].values + ) + + +class TestPredictProbabilityMaximumLikelihoodEstimator: + """Test behaviour of predict_probability using MLE""" + + def test_expected_probabilities_are_predicted( + self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood + ): + """Probabilities should return exactly correct on a hand computable scenario""" + + bn.fit_cpds(train_data_discrete) + probability = bn.predict_probability(test_data_c_discrete, "c") + + assert all( + np.isclose( + 
probability.values.flatten(), test_data_c_likelihood.values.flatten() + ) + ) + + def test_missing_parent( + self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood + ): + """Probabilities made when parents are missing should still be reasonably accurate""" + + bn.fit_cpds(train_data_discrete) + probability = bn.predict_probability( + test_data_c_discrete[["a", "b", "c", "d"]], "c" + ) + + n = len(probability.values.flatten()) + + accuracy = ( + np.count_nonzero( + [ + 1 if math.isclose(a, b, abs_tol=0.15) else 0 + for a, b in zip( + probability.values.flatten(), + test_data_c_likelihood.values.flatten(), + ) + ] + ) + / n + ) + + assert accuracy > 0.8 + + def test_missing_non_parent( + self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood + ): + """It should be possible to make predictions with non-parent nodes missing""" + + bn.fit_cpds(train_data_discrete) + probability = bn.predict_probability( + test_data_c_discrete[["b", "c", "d", "e"]], "c" + ) + assert all( + np.isclose( + probability.values.flatten(), test_data_c_likelihood.values.flatten() + ) + ) + + +class TestPredictProbabilityBayesianEstimator: + """Test behaviour of predict_probability using BayesianEstimator""" + + def test_expected_probabilities_are_predicted( + self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood + ): + """Probabilities should return exactly correct on a hand computable scenario""" + + bn.fit_cpds( + train_data_discrete, + method="BayesianEstimator", + bayes_prior="BDeu", + equivalent_sample_size=1, + ) + probability = bn.predict_probability(test_data_c_discrete, "c") + assert all( + np.isclose( + probability.values.flatten(), + test_data_c_likelihood.values.flatten(), + atol=0.1, + ) + ) + + +class TestFitNodesStatesAndCPDs: + """Test behaviour of helper function""" + + def test_behaves_same_as_separate_calls(self, train_data_idx, train_data_discrete): + bn1 = DynamicBayesianNetwork(from_pandas(train_data_idx, 
w_threshold=0.3)) + bn2 = DynamicBayesianNetwork(from_pandas(train_data_idx, w_threshold=0.3)) + + bn1.fit_node_states(train_data_discrete).fit_cpds(train_data_discrete) + bn2.fit_node_states_and_cpds(train_data_discrete) + + assert bn1.edges == bn2.edges + assert bn1.node_states == bn2.node_states + + cpds1 = bn1.cpds + cpds2 = bn2.cpds + + assert cpds1.keys() == cpds2.keys() + + for k, df in cpds1.items(): + assert df.equals(cpds2[k]) + + +class TestLatentVariable: + @staticmethod + def mean_absolute_error(cpds_a, cpds_b): + """Compute the absolute error among each single parameter and average them out""" + + mae = 0 + n_param = 0 + + for node in cpds_a.keys(): + err = np.abs(cpds_a[node] - cpds_b[node]).values + mae += np.sum(err) + n_param += err.shape[0] * err.shape[1] + + return mae / n_param + + def test_em_algorithm(self): # pylint: disable=too-many-locals + """ + Test if `DynamicBayesianNetwork` works with EM algorithm. + We use a naive bayes + parents + an extra node not related to the latent variable. + """ + + # p0 p1 p2 + # \ | / + # z + # / | \ + # c0 c1 c2 + # | + # cc0 + np.random.seed(22) + + data, sm, _, true_lv_values = naive_bayes_plus_parents( + percentage_not_missing=0.1, + samples=1000, + p_z=0.7, + p_c=0.7, + ) + data["cc_0"] = np.where( + np.random.random(len(data)) < 0.5, data["c_0"], (data["c_0"] + 1) % 3 + ) + data.drop(columns=["z"], inplace=True) + + complete_data = data.copy(deep=True) + complete_data["z"] = true_lv_values + + # Baseline model: the structure of the figure trained with complete data. 
We try to reproduce it + complete_bn = DynamicBayesianNetwork( + StructureModel(list(sm.edges) + [("c_0", "cc_0")]) + ) + complete_bn.fit_node_states_and_cpds(complete_data) + + # BN without latent variable: All `p`s are connected to all `c`s + `c0` ->`cc0` + sm_no_lv = StructureModel( + [(f"p_{p}", f"c_{c}") for p in range(3) for c in range(3)] + + [("c_0", "cc_0")] + ) + bn = DynamicBayesianNetwork(sm_no_lv) + bn.fit_node_states(data) + bn.fit_cpds(data) + + # TEST 1: cc_0 does not depend on the latent variable so: + assert np.all(bn.cpds["cc_0"] == complete_bn.cpds["cc_0"]) + + # BN with latent variable + # When we add the latent variable, we add the edges in the image above + # and remove the connection among `p`s and `c`s + edges_to_add = list(sm.edges) + edges_to_remove = [(f"p_{p}", f"c_{c}") for p in range(3) for c in range(3)] + bn.add_node("z", edges_to_add, edges_to_remove) + bn.fit_latent_cpds("z", [0, 1, 2], data, stopping_delta=0.001) + + # TEST 2: cc_0 CPD should remain untouched by the EM algorithm + assert np.all(bn.cpds["cc_0"] == complete_bn.cpds["cc_0"]) + + # TEST 3: We should recover the correct CPDs quite accurately + assert bn.cpds.keys() == complete_bn.cpds.keys() + assert self.mean_absolute_error(bn.cpds, complete_bn.cpds) < 0.01 + + # TEST 4: Inference over recovered CPDs should be also accurate + eng = InferenceEngine(bn) + query = eng.query() + n_rows = complete_data.shape[0] + + for node in query: + assert ( + np.abs(query[node][0] - sum(complete_data[node] == 0) / n_rows) < 1e-2 + ) + assert ( + np.abs(query[node][1] - sum(complete_data[node] == 1) / n_rows) < 1e-2 + ) + + # TEST 5: Inference using predict and predict_probability functions + report = classification_report(bn, complete_data, "z") + _, auc = roc_auc(bn, complete_data, "z") + complete_report = classification_report(complete_bn, complete_data, "z") + _, complete_auc = roc_auc(complete_bn, complete_data, "z") + + for category, metrics in report.items(): + if 
isinstance(metrics, dict): + for key, val in metrics.items(): + assert np.abs(val - complete_report[category][key]) < 1e-2 + else: + assert np.abs(metrics - complete_report[category]) < 1e-2 + + assert np.abs(auc - complete_auc) < 1e-2 + + +class TestAddNode: + def test_add_node_not_in_edges_to_add(self): + """An error should be raised if the latent variable is NOT part of the edges to add""" + + with pytest.raises( + ValueError, + match="Should only add edges containing node 'd'", + ): + _, sm, _, _ = naive_bayes_plus_parents() + sm = StructureModel(list(sm.edges)) + bn = DynamicBayesianNetwork(sm) + bn.add_node("d", [("a", "z"), ("b", "z")], []) + + def test_add_node_in_edges_to_remove(self): + """An error should be raised if the latent variable is part of the edges to remove""" + + with pytest.raises( + ValueError, + match="Should only remove edges NOT containing node 'd'", + ): + _, sm, _, _ = naive_bayes_plus_parents() + sm = StructureModel(list(sm.edges)) + bn = DynamicBayesianNetwork(sm) + bn.add_node("d", [], [("a", "d"), ("b", "d")]) + + +class TestFitLatentCPDs: + @pytest.mark.parametrize("lv_name", [None, [], set(), {}, tuple(), 123, {}]) + def test_fit_invalid_lv_name(self, lv_name): + """An error should be raised if the latent variable is of an invalid type""" + + with pytest.raises( + ValueError, + match=r"Invalid latent variable name *", + ): + df, sm, _, _ = naive_bayes_plus_parents() + sm = StructureModel(list(sm.edges)) + bn = DynamicBayesianNetwork(sm) + bn.fit_latent_cpds(lv_name, [0, 1, 2], df) + + def test_fit_lv_not_added(self): + """An error should be raised if the latent variable is not added to the network yet""" + + with pytest.raises( + ValueError, + match=r"Latent variable 'd' not added to the network", + ): + df, sm, _, _ = naive_bayes_plus_parents() + sm = StructureModel(list(sm.edges)) + bn = DynamicBayesianNetwork(sm) + bn.fit_latent_cpds("d", [0, 1, 2], df) + + @pytest.mark.parametrize("lv_states", [None, [], set(), {}]) + def 
test_fit_invalid_lv_states(self, lv_states): + """An error should be raised if the latent variable has invalid states""" + + with pytest.raises( + ValueError, + match="Latent variable 'd' contains no states", + ): + df, sm, _, _ = naive_bayes_plus_parents() + sm = StructureModel(list(sm.edges)) + bn = DynamicBayesianNetwork(sm) + bn.add_node("d", [("z", "d")], []) + bn.fit_latent_cpds("d", lv_states, df) + + +class TestSetCPD: + """Test behaviour of adding a self-defined cpd""" + + def test_set_cpd(self, bn, good_cpd): + """The CPD of the target node should be the same as the self-defined table after adding""" + + bn.set_cpd("b", good_cpd) + assert bn.cpds["b"].values.tolist() == good_cpd.values.tolist() + + def test_set_other_cpd(self, bn, good_cpd): + """The CPD of nodes other than the target node should not be affected""" + + cpd = bn.cpds["a"].values.tolist() + bn.set_cpd("b", good_cpd) + cpd_after_adding = bn.cpds["a"].values.tolist() + + assert all( + val == val_after_adding + for val, val_after_adding in zip(*(cpd, cpd_after_adding)) + ) + + def test_set_cpd_to_non_existent_node(self, bn, good_cpd): + """Should raise error if adding a cpd to a non-existing node in Bayesian Network""" + + with pytest.raises( + ValueError, + match=r'Non-existing node "test"', + ): + bn.set_cpd("test", good_cpd) + + def test_set_bad_cpd(self, bn, bad_cpd): + """Should raise error if the probability values do not sum up to 1 in the table""" + + with pytest.raises( + ValueError, + match=r"Sum or integral of conditional probabilites for node b is not equal to 1.", + ): + bn.set_cpd("b", bad_cpd) + + def test_no_overwritten_after_setting_bad_cpd(self, bn, bad_cpd): + """The cpd of bn won't be overwritten if adding a bad cpd""" + + original_cpd = bn.cpds["b"].values.tolist() + + try: + bn.set_cpd("b", bad_cpd) + except ValueError: + assert bn.cpds["b"].values.tolist() == original_cpd + + def test_bad_node_index(self, bn, good_cpd): + """Should raise an error when setting bad node 
index""" + + bad_cpd = good_cpd + bad_cpd.index.name = "test" + + with pytest.raises( + IndexError, + match=r"Wrong index values. Please check your indices", + ): + bn.set_cpd("b", bad_cpd) + + def test_bad_node_states_index(self, bn, good_cpd): + """Should raise an error when setting bad node states index""" + + bad_cpd = good_cpd.reindex([1, 2, 3]) + + with pytest.raises( + IndexError, + match=r"Wrong index values. Please check your indices", + ): + bn.set_cpd("b", bad_cpd) + + def test_bad_parent_node_index(self, bn, good_cpd): + """Should raise an error when setting bad parent node index""" + + bad_cpd = good_cpd + bad_cpd.columns = bad_cpd.columns.rename("test", level=1) + + with pytest.raises( + IndexError, + match=r"Wrong index values. Please check your indices", + ): + bn.set_cpd("b", bad_cpd) + + def test_bad_parent_node_states_index(self, bn, good_cpd): + """Should raise an error when setting bad parent node states index""" + + bad_cpd = good_cpd + bad_cpd.columns.set_levels(["test1", "test2"], level=0, inplace=True) + + with pytest.raises( + IndexError, + match=r"Wrong index values. Please check your indices", + ): + bn.set_cpd("b", bad_cpd) + + +class TestCPDsProperty: + """Test behaviour of the CPDs property""" + + def test_row_index_of_state_values(self, bn): + """CPDs should have row index set to values of all possible states of the node""" + + assert bn.cpds["a"].index.tolist() == sorted(list(bn.node_states["a"])) + + def test_col_index_of_parent_state_combinations(self, bn): + """CPDs should have a column multi-index of parent state permutations""" + + assert bn.cpds["a"].columns.names == ["b", "d"] + + +class TestInit: + """Test behaviour when constructing a DynamicBayesianNetwork""" + + def test_cycles_in_structure(self): + """An error should be raised if cycles are present""" + + with pytest.raises( + ValueError, + match=r"The given structure is not acyclic\. 
" + r"Please review the following cycle\.*", + ): + DynamicBayesianNetwork(StructureModel([(0, 1), (1, 2), (2, 0)])) + + @pytest.mark.parametrize( + "test_input,n_components", + [([(0, 1), (1, 2), (3, 4), (4, 6)], 2), ([(0, 1), (1, 2), (3, 4), (5, 6)], 3)], + ) + def test_disconnected_components(self, test_input, n_components): + """An error should be raised if there is more than one graph component""" + + with pytest.raises( + ValueError, + match=r"The given structure has " + + str(n_components) + + r" separated graph components\. " + r"Please make sure it has only one\.", + ): + DynamicBayesianNetwork(StructureModel(test_input)) + + +class TestStructure: + """Test behaviour of the property structure""" + + def test_get_structure(self): + """The structure retrieved should be the same""" + + sm = StructureModel() + sm.add_weighted_edges_from([(1, 2, 2.0)], origin="unknown") + sm.add_weighted_edges_from([(1, 3, 1.0)], origin="learned") + sm.add_weighted_edges_from([(3, 5, 0.7)], origin="expert") + + bn = DynamicBayesianNetwork(sm) + + sm_from_bn = bn.structure + + assert set(sm.edges.data("origin")) == set(sm_from_bn.edges.data("origin")) + assert set(sm.edges.data("weight")) == set(sm_from_bn.edges.data("weight")) + assert set(sm.nodes) == set(sm_from_bn.nodes) + + def test_set_structure(self): + """An error should be raised if setting the structure""" + + sm = StructureModel() + sm.add_weighted_edges_from([(1, 2, 2.0)], origin="unknown") + sm.add_weighted_edges_from([(1, 3, 1.0)], origin="learned") + sm.add_weighted_edges_from([(3, 5, 0.7)], origin="expert") + + bn = DynamicBayesianNetwork(sm) + + new_sm = StructureModel() + sm.add_weighted_edges_from([(2, 5, 3.0)], origin="unknown") + sm.add_weighted_edges_from([(2, 3, 2.0)], origin="learned") + sm.add_weighted_edges_from([(3, 4, 1.7)], origin="expert") + + with pytest.raises(AttributeError, match=r"can't set attribute"): + bn.structure = new_sm From e1ac8e002b79bec615a0266fc812032413b74143 Mon Sep 17 00:00:00 
2001 From: Liam Adams Date: Thu, 29 Sep 2022 19:31:29 -0500 Subject: [PATCH 02/13] push branch for Issue 121 --- causalnex/network/network.py | 6 +- causalnex/structure/__init__.py | 5 + causalnex/structure/dynotears.py | 21 +- causalnex/structure/structuremodel.py | 282 +++++++- tests/test_dynamicbayesiannetwork.py | 907 +------------------------- 5 files changed, 309 insertions(+), 912 deletions(-) diff --git a/causalnex/network/network.py b/causalnex/network/network.py index 3a6a230..3408059 100644 --- a/causalnex/network/network.py +++ b/causalnex/network/network.py @@ -43,7 +43,7 @@ from pgmpy.models import BayesianModel from causalnex.estimator.em import EMSingleLatentVariable -from causalnex.structure import StructureModel +from causalnex.structure import StructureModel, DynamicStructureModel from causalnex.utils.pgmpy_utils import pd_to_tabular_cpd @@ -802,9 +802,9 @@ class DynamicBayesianNetwork(BayesianNetwork): 'traffic_heavy': {0: 0.25, 1: 0.75, 2: 0.6666666666666666, 3: 0.6666666666666666}} """ - def __init__(self, structure: StructureModel): + def __init__(self, structure: DynamicStructureModel): """ - Create a ``DynamicBayesianNetwork`` with a DAG defined by ``StructureModel``. + Create a ``DynamicBayesianNetwork`` with a DAG defined by ``DynamicStructureModel``. Args: structure: a graph representing a causal relationship between variables. 
diff --git a/causalnex/structure/__init__.py b/causalnex/structure/__init__.py index e0f2315..4980400 100644 --- a/causalnex/structure/__init__.py +++ b/causalnex/structure/__init__.py @@ -35,9 +35,14 @@ "notears", "dynotears", "data_generators", + "node", "DAGRegressor", "DAGClassifier", + "DynamicStructureModel", + "DynamicStructureNode" ] from .pytorch.sklearn import DAGClassifier, DAGRegressor from .structuremodel import StructureModel +from .structuremodel import DynamicStructureModel +from .structuremodel import DynamicStructureNode \ No newline at end of file diff --git a/causalnex/structure/dynotears.py b/causalnex/structure/dynotears.py index 42faf29..65231ba 100644 --- a/causalnex/structure/dynotears.py +++ b/causalnex/structure/dynotears.py @@ -38,7 +38,8 @@ import scipy.linalg as slin import scipy.optimize as sopt -from causalnex.structure import StructureModel +from causalnex.structure import DynamicStructureModel +from causalnex.structure import DynamicStructureNode from causalnex.structure.transformers import DynamicDataTransformer @@ -53,7 +54,7 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments tabu_edges: List[Tuple[int, int, int]] = None, tabu_parent_nodes: List[int] = None, tabu_child_nodes: List[int] = None, -) -> StructureModel: +) -> DynamicStructureModel: """ Learn the graph structure of a Dynamic Bayesian Network describing conditional dependencies between variables in data. The input data is a time series or a list of realisations of a same time series. 
@@ -122,9 +123,9 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments tabu_child_nodes, ) - sm = StructureModel() - sm.add_nodes_from( - [f"{var}_lag{l_val}" for var in col_idx.keys() for l_val in range(p + 1)] + sm = DynamicStructureModel() + sm.add_nodes( + [DynamicStructureNode(var, l_val) for var in col_idx.keys() for l_val in range(p + 1)] ) sm.add_weighted_edges_from( [ @@ -166,7 +167,7 @@ def from_numpy_dynamic( # pylint: disable=too-many-arguments tabu_edges: List[Tuple[int, int, int]] = None, tabu_parent_nodes: List[int] = None, tabu_child_nodes: List[int] = None, -) -> StructureModel: +) -> DynamicStructureModel: """ Learn the graph structure of a Dynamic Bayesian Network describing conditional dependencies between variables in data. The input data is time series data present in numpy arrays X and Xlags. @@ -254,7 +255,7 @@ def from_numpy_dynamic( # pylint: disable=too-many-arguments def _matrices_to_structure_model( w_est: np.ndarray, a_est: np.ndarray -) -> StructureModel: +) -> DynamicStructureModel: """ Converts the matrices output by dynotears (W and A) into a StructureModel We use the following convention: @@ -268,13 +269,13 @@ def _matrices_to_structure_model( StructureModel representing the structure learnt """ - sm = StructureModel() + sm = DynamicStructureModel() lag_cols = [ - f"{var}_lag{l_val}" + DynamicStructureNode(var, l_val) for l_val in range(1 + (a_est.shape[0] // a_est.shape[1])) for var in range(a_est.shape[1]) ] - sm.add_nodes_from(lag_cols) + sm.add_nodes(lag_cols) sm.add_edges_from( [ (lag_cols[i], lag_cols[j], dict(weight=w_est[i, j])) diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index 0343117..2b9af2a 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -31,12 +31,13 @@ ``StructureModel`` is a class that describes relationships between variables as a graph. 
""" -from typing import Any, Hashable, List, Set, Tuple, Union +from typing import Any, Hashable, List, Set, Tuple, Union, NamedTuple import networkx as nx import numpy as np from networkx.exception import NodeNotFound - +import inspect +import types def _validate_origin(origin: str) -> None: """ @@ -97,6 +98,7 @@ def __init__(self, incoming_graph_data=None, origin="unknown", **attr): super().__init__(incoming_graph_data, **attr) for u_of_edge, v_of_edge in self.edges: + print(f'in for loop in init {u_of_edge}, {v_of_edge}') self[u_of_edge][v_of_edge]["origin"] = origin def to_directed_class(self): @@ -333,3 +335,279 @@ def get_markov_blanket( ] ) return blanket + +class DynamicStructureNode(NamedTuple): + node: Union[int, str] + time_step: int + + @classmethod + def __instancecheck__(cls, instance): + print(f'inside instance check function {type(instance)}') + if hasattr(instance, 'node') and hasattr(instance, 'time_step') and hasattr(instance, 'get_node_name'): + return True + else: + return False + + def get_node_name(self): + return f'{self.node}_lag{self.time_step}' + + def __eq__(self, other): + if isinstance(other, DynamicStructureNode): + return self.get_node_name() == other.get_node_name() + return False + +def checkargs(function): + def _f(*arguments, **attr): + for index, argument in enumerate(inspect.getfullargspec(function)[0]): + if argument == 'self': + continue + try: + if isinstance(arguments[index], list): + for arg in arguments[index]: + if isinstance(arg, tuple) and not isinstance(arg, DynamicStructureNode): + if len(arg) == 3: + if not all(isinstance(n, DynamicStructureNode) for n in arg[:-1]): + raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) + else: + if not all(isinstance(n, DynamicStructureNode) for n in arg): + raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) + else: + if not isinstance(arg, 
function.__annotations__[argument].__args__[0]): + raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) + elif isinstance(arguments[index], types.GeneratorType): + # this comes from networkx, coerce into correct types + pass + elif not isinstance(arguments[index], function.__annotations__[argument]): + raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) + except IndexError as e: + # index error here means arg was passed implicitly + break + return function(*arguments, **attr) + _f.__doc__ = function.__doc__ + return _f + +def _validate_dsm_init_args(incoming_graph_data): + if isinstance(incoming_graph_data, list): + assert all(isinstance(n[0], DynamicStructureNode) and isinstance(n[1], DynamicStructureNode) for n in incoming_graph_data) + +class DynamicStructureModel(StructureModel): + """ + Base class for structure models, which are an extension of ``networkx.DiGraph``. + + A ``StructureModel`` stores nodes and edges with optional data, or attributes. + + Edges have one required attribute, "origin", which describes how the edge was created. + Origin can be one of either unknown, learned, or expert. + + StructureModel hold directed edges, describing a cause -> effect relationship. + Cycles are permitted within a ``StructureModel``. + + Nodes can be arbitrary (hashable) Python objects with optional key/value attributes. + By convention None is not used as a node. + + Edges are represented as links between nodes with optional key/value attributes. + """ + + def __init__(self, incoming_graph_data=None, origin="unknown", **attr): + """ + Create a ``StructureModel`` with incoming_graph_data, which has come from some origin. + + Args: + incoming_graph_data (Optional): input graph (optional, default: None) + Data to initialize graph. If None (default) an empty graph is created. 
+ The data can be any format that is supported by the to_networkx_graph() + function, currently including edge list, dict of dicts, dict of lists, + NetworkX graph, NumPy matrix or 2d ndarray, SciPy sparse matrix, or PyGraphviz graph. + + origin (str): label for how the edges were created. Can be one of: + - unknown: edges exist for an unknown reason; + - learned: edges were created as the output of a machine-learning process; + - expert: edges were created by a domain expert. + + attr : Attributes to add to graph as key/value pairs (no attributes by default). + """ + if incoming_graph_data is not None: + _validate_dsm_init_args(incoming_graph_data) + super().__init__(incoming_graph_data, origin, **attr) + + + @checkargs + def add_node(self, dnode: DynamicStructureNode): + super().add_nodes_from([dnode.get_node_name()]) + + @checkargs + def add_nodes(self, dnodes: List[DynamicStructureNode]): + node_names = [dnode.get_node_name() for dnode in dnodes] + super().add_nodes_from(node_names) + + def to_directed_class(self): + """ + Returns the class to use for directed copies. + See :func:`networkx.DiGraph.to_directed()`. + """ + return DynamicStructureModel + + @checkargs + def get_target_subgraph(self, node: DynamicStructureNode) -> "DynamicStructureModel": + """ + Get the subgraph with the specified node. + + Args: + node: the name of the node. + + Returns: + The subgraph with the target node. + + Raises: + NodeNotFound: if the node is not found in the graph. 
+ """ + node_name = node.get_node_name() + if node_name in self.nodes: + print(f'node {node} in self nodes {self.nodes}') + for component in nx.weakly_connected_components(self): + subgraph = self.subgraph(component).copy() + + if node_name in set(subgraph.nodes): + return subgraph + + raise NodeNotFound(f"Node {node} not found in the graph") + + @checkargs + def get_markov_blanket( + self, nodes: Union[DynamicStructureNode, List[DynamicStructureNode], Set[DynamicStructureNode]] + ) -> "DynamicStructureModel": + """ + Get Markov blanket of specified target nodes + + Args: + nodes: Target node name or list/set of target nodes + + Returns: + Markov blanket of the target node(s) + + Raises: + NodeNotFound: if one of the target nodes is not found in the graph. + """ + if not isinstance(nodes, (list, set)): + nodes = [nodes] + + blanket_nodes = set() + + for node in set(nodes): # Ensure target nodes are unique + if node not in set(self.nodes): + raise NodeNotFound(f"Node {node} not found in the graph.") + + blanket_nodes.add(node) + blanket_nodes.update(self.predecessors(node)) + + for child in self.successors(node): + blanket_nodes.add(child) + blanket_nodes.update(self.predecessors(child)) + + blanket = DynamicStructureModel() + blanket.add_nodes(blanket_nodes) + blanket.add_weighted_edges_from( + [ + (u, v, w) + for u, v, w in self.edges(data="weight") + if u in blanket_nodes and v in blanket_nodes + ] + ) + return blanket + + # disabled: W0221: Parameters differ from overridden 'add_edge' method (arguments-differ) + # this has been disabled because origin tracking is required for CausalGraphs + # implementing it in this way allows all 3rd party libraries and applications to + # integrate seamlessly, where edges will be given origin="unknown" where not provided + @checkargs + def add_edges_from( + self, + ebunch_to_add: Union[Set[Tuple[DynamicStructureNode, DynamicStructureNode]], List[Tuple[DynamicStructureNode, DynamicStructureNode]]], + origin: str = "unknown", + 
**attr, + ): + """ + Adds a bunch of causal relationships, u -> v. + + If u or v do not currently exists in the ``StructureModel`` then they will be created. + + By default relationships will be given origin="unknown", + but may also be given "learned" or "expert" origin. + + Notes: + Adding an edge that already exists will replace the existing edge. + See :func:`networkx.DiGraph.add_edges_from`. + + Args: + ebunch_to_add: container of edges. + Each edge given in the container will be added to the graph. + The edges must be given as 2-tuples (u, v) or + 3-tuples (u, v, d) where d is a dictionary containing edge data. + origin: label for how the edges were created. One of: + - unknown: edges exist for an unknown reason. + - learned: edges were created as the output of a machine-learning process. + - expert: edges were created by a domain expert. + **attr: Attributes to add to edge as key/value pairs (no attributes by default). + """ + _validate_origin(origin) + + if isinstance(ebunch_to_add, types.GeneratorType): + dsn_ebunch = [] + for e in ebunch_to_add: + if len(e) == 3: + dsn_ebunch.append((DynamicStructureNode(e[0][0], e[0][-1]).get_node_name(), DynamicStructureNode(e[1][0], e[1][-1]).get_node_name(), e[2])) + else: + dsn_ebunch.append((DynamicStructureNode(e[0][0], e[0][-1]).get_node_name(), DynamicStructureNode(e[1][0], e[1][-1]).get_node_name())) + else: + if len(ebunch_to_add[0]) == 3: + dsn_ebunch = [(e[0].get_node_name(), e[1].get_node_name(), e[2]) for e in ebunch_to_add] + else: + dsn_ebunch = [(e[0].get_node_name(), e[1].get_node_name()) for e in ebunch_to_add] + attr.update({"origin": origin}) + super().add_edges_from(dsn_ebunch, **attr) + + # disabled: W0221: Parameters differ from overridden 'add_edge' method (arguments-differ) + # this has been disabled because origin tracking is required for CausalGraphs + # implementing it in this way allows all 3rd party libraries and applications to + # integrate seamlessly, where edges will be given 
origin="unknown" where not provided + @checkargs + def add_weighted_edges_from( + self, + ebunch_to_add: Union[Set[Tuple[DynamicStructureNode, DynamicStructureNode, float]], List[Tuple[DynamicStructureNode, DynamicStructureNode, float]]], + weight: str = "weight", + origin: str = "unknown", + **attr + ): + """ + Adds a bunch of weighted causal relationships, u -> v. + + If u or v do not currently exists in the ``StructureModel`` then they will be created. + + By default relationships will be given origin="unknown", + but may also be given "learned" or "expert" origin. + + Notes: + Adding an edge that already exists will replace the existing edge. + See :func:`networkx.DiGraph.add_edges_from`. + + Args: + ebunch_to_add: container of edges. + Each edge given in the container will be added to the graph. + The edges must be given as 2-tuples (u, v) or + 3-tuples (u, v, d) where d is a dictionary containing edge data. + weight : string, optional (default='weight'). + The attribute name for the edge weights to be added. + origin: label for how the edges were created. One of: + - unknown: edges exist for an unknown reason; + - learned: edges were created as the output of a machine-learning process; + - expert: edges were created by a domain expert. + **attr: Attributes to add to edge as key/value pairs (no attributes by default). 
+ """ + _validate_origin(origin) + + if isinstance(ebunch_to_add, types.GeneratorType): + dsn_ebunch = [(DynamicStructureNode(e[0][0], e[0][-1]).get_node_name(), DynamicStructureNode(e[1][0], e[1][-1]).get_node_name(), e[2]) for e in ebunch_to_add] + else: + dsn_ebunch = [(e[0].get_node_name(), e[1].get_node_name(), e[2]) for e in ebunch_to_add] + attr.update({"origin": origin}) + super().add_weighted_edges_from(dsn_ebunch, weight=weight, **attr) \ No newline at end of file diff --git a/tests/test_dynamicbayesiannetwork.py b/tests/test_dynamicbayesiannetwork.py index 7e0ed1e..a4ac757 100644 --- a/tests/test_dynamicbayesiannetwork.py +++ b/tests/test_dynamicbayesiannetwork.py @@ -26,900 +26,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - -import numpy as np -import pandas as pd -import pytest - -from causalnex.evaluation import classification_report, roc_auc -from causalnex.inference import InferenceEngine -from causalnex.network import DynamicBayesianNetwork -from causalnex.structure import StructureModel -from causalnex.structure.notears import from_pandas - -from .estimator.test_em import naive_bayes_plus_parents - - -class TestFitNodeStates: - """Test behaviour of fit node states method""" - - @pytest.mark.parametrize( - "weighted_edges, data", - [ - ([("a", "b", 1)], pd.DataFrame([[1, 1]], columns=["a", "b"])), - ( - [("a", "b", 1)], - pd.DataFrame([[1, 1, 1, 1]], columns=["a", "b", "c", "d"]), - ), - # c and d are isolated nodes in the data - ], - ) - def test_all_nodes_included(self, weighted_edges, data): - """No errors if all the nodes can be found in the columns of training data""" - cg = StructureModel() - cg.add_weighted_edges_from(weighted_edges) - bn = DynamicBayesianNetwork(cg).fit_node_states(data) - assert all(node in data.columns for node in bn.node_states.keys()) - - def test_all_states_included(self): - """All states in a node should be included""" - cg = StructureModel() 
- cg.add_weighted_edges_from([("a", "b", 1)]) - bn = DynamicBayesianNetwork(cg).fit_node_states( - pd.DataFrame([[i, i] for i in range(10)], columns=["a", "b"]) - ) - assert all(v in bn.node_states["a"] for v in range(10)) - - def test_fit_with_null_states_raises_error(self): - """An error should be raised if fit is called with null data""" - cg = StructureModel() - cg.add_weighted_edges_from([("a", "b", 1)]) - with pytest.raises(ValueError, match="node '.*' contains None state"): - DynamicBayesianNetwork(cg).fit_node_states( - pd.DataFrame([[None, 1]], columns=["a", "b"]) - ) - - def test_fit_with_missing_feature_in_data(self): - """An error should be raised if fit is called with missing feature in data""" - cg = StructureModel() - - cg.add_weighted_edges_from([("a", "e", 1)]) - with pytest.raises( - KeyError, - match="The data does not cover all the features found in the Bayesian Network. " - "Please check the following features: {'e'}", - ): - DynamicBayesianNetwork(cg).fit_node_states( - pd.DataFrame([[1, 1, 1, 1]], columns=["a", "b", "c", "d"]) - ) - - -class TestFitCPDSErrors: - """Test errors for fit CPDs method""" - - def test_invalid_method(self, bn, train_data_discrete): - """a value error should be raised in an invalid method is provided""" - - with pytest.raises(ValueError, match=r"unrecognised method.*"): - bn.fit_cpds(train_data_discrete, method="INVALID") - - def test_invalid_prior(self, bn, train_data_discrete): - """a value error should be raised in an invalid prior is provided""" - - with pytest.raises(ValueError, match=r"unrecognised bayes_prior.*"): - bn.fit_cpds( - train_data_discrete, method="BayesianEstimator", bayes_prior="INVALID" - ) - - -class TestFitCPDsMaximumLikelihoodEstimator: - """Test behaviour of fit_cpds using MLE""" - - def test_cause_only_node(self, bn, train_data_discrete, train_data_discrete_cpds): - """Test that probabilities are fit correctly to nodes which are not caused by other nodes""" - - 
bn.fit_cpds(train_data_discrete) - cpds = bn.cpds - - assert ( - np.mean( - np.abs( - cpds["d"].values.reshape(2) - - train_data_discrete_cpds["d"].reshape(2) - ) - ) - < 1e-7 - ) - assert ( - np.mean( - np.abs( - cpds["e"].values.reshape(2) - - train_data_discrete_cpds["e"].reshape(2) - ) - ) - < 1e-7 - ) - - def test_dependent_node(self, bn, train_data_discrete, train_data_discrete_cpds): - """Test that probabilities are fit correctly to nodes that are caused by other nodes""" - - bn.fit_cpds(train_data_discrete) - cpds = bn.cpds - - assert ( - np.mean( - np.abs( - cpds["a"].values.reshape(24) - - train_data_discrete_cpds["a"].reshape(24) - ) - ) - < 1e-7 - ) - assert ( - np.mean( - np.abs( - cpds["b"].values.reshape(12) - - train_data_discrete_cpds["b"].reshape(12) - ) - ) - < 1e-7 - ) - assert ( - np.mean( - np.abs( - cpds["c"].values.reshape(60) - - train_data_discrete_cpds["c"].reshape(60) - ) - ) - < 1e-7 - ) - - def test_fit_missing_states(self): - """test issues/15: should be possible to fit with missing states""" - - sm = StructureModel([("a", "b"), ("c", "b")]) - bn = DynamicBayesianNetwork(sm) - - train = pd.DataFrame( - data=[[0, 0, 1], [1, 0, 1], [1, 1, 1]], columns=["a", "b", "c"] - ) - test = pd.DataFrame( - data=[[0, 0, 1], [1, 0, 1], [1, 1, 2]], columns=["a", "b", "c"] - ) - data = pd.concat([train, test]) - - bn.fit_node_states(data) - bn.fit_cpds(train) - - assert bn.cpds["c"].loc[1][0] == 1 - assert bn.cpds["c"].loc[2][0] == 0 - - -class TestFitBayesianEstimator: - """Test behaviour of fit_cpds using BE""" - - def test_cause_only_node_bdeu( - self, bn, train_data_discrete, train_data_discrete_cpds - ): - """Test that probabilities are fit correctly to nodes which are not caused by other nodes""" - - bn.fit_cpds( - train_data_discrete, - method="BayesianEstimator", - bayes_prior="BDeu", - equivalent_sample_size=5, - ) - cpds = bn.cpds - - assert ( - np.mean( - np.abs( - cpds["d"].values.reshape(2) - - train_data_discrete_cpds["d"].reshape(2) - ) 
- ) - < 0.02 - ) - assert ( - np.mean( - np.abs( - cpds["e"].values.reshape(2) - - train_data_discrete_cpds["e"].reshape(2) - ) - ) - < 0.02 - ) - - def test_cause_only_node_k2( - self, bn, train_data_discrete, train_data_discrete_cpds - ): - """Test that probabilities are fit correctly to nodes which are not caused by other nodes""" - - bn.fit_cpds(train_data_discrete, method="BayesianEstimator", bayes_prior="K2") - cpds = bn.cpds - - assert ( - np.mean( - np.abs( - cpds["d"].values.reshape(2) - - train_data_discrete_cpds["d"].reshape(2) - ) - ) - < 0.02 - ) - assert ( - np.mean( - np.abs( - cpds["e"].values.reshape(2) - - train_data_discrete_cpds["e"].reshape(2) - ) - ) - < 0.02 - ) - - def test_dependent_node_bdeu( - self, bn, train_data_discrete, train_data_discrete_cpds - ): - """Test that probabilities are fit correctly to nodes that are caused by other nodes""" - - bn.fit_cpds( - train_data_discrete, - method="BayesianEstimator", - bayes_prior="BDeu", - equivalent_sample_size=1, - ) - cpds = bn.cpds - - assert ( - np.mean( - np.abs( - cpds["a"].values.reshape(24) - - train_data_discrete_cpds["a"].reshape(24) - ) - ) - < 0.02 - ) - assert ( - np.mean( - np.abs( - cpds["b"].values.reshape(12) - - train_data_discrete_cpds["b"].reshape(12) - ) - ) - < 0.02 - ) - assert ( - np.mean( - np.abs( - cpds["c"].values.reshape(60) - - train_data_discrete_cpds["c"].reshape(60) - ) - ) - < 0.02 - ) - - def test_dependent_node_k2( - self, bn, train_data_discrete, train_data_discrete_cpds_k2 - ): - """Test that probabilities are fit correctly to nodes that are caused by other nodes""" - - bn.fit_cpds(train_data_discrete, method="BayesianEstimator", bayes_prior="K2") - cpds = bn.cpds - - assert ( - np.mean( - np.abs( - cpds["a"].values.reshape(24) - - train_data_discrete_cpds_k2["a"].reshape(24) - ) - ) - < 1e-7 - ) - assert ( - np.mean( - np.abs( - cpds["b"].values.reshape(12) - - train_data_discrete_cpds_k2["b"].reshape(12) - ) - ) - < 1e-7 - ) - assert ( - np.mean( - 
np.abs( - cpds["c"].values.reshape(60) - - train_data_discrete_cpds_k2["c"].reshape(60) - ) - ) - < 1e-7 - ) - - def test_fit_missing_states(self): - """test issues/15: should be possible to fit with missing states""" - - sm = StructureModel([("a", "b"), ("c", "b")]) - bn = DynamicBayesianNetwork(sm) - - train = pd.DataFrame( - data=[[0, 0, 1], [1, 0, 1], [1, 1, 1]], columns=["a", "b", "c"] - ) - test = pd.DataFrame( - data=[[0, 0, 1], [1, 0, 1], [1, 1, 2]], columns=["a", "b", "c"] - ) - data = pd.concat([train, test]) - - bn.fit_node_states(data) - bn.fit_cpds(train, method="BayesianEstimator", bayes_prior="K2") - - assert bn.cpds["c"].loc[1][0] == 0.8 - assert bn.cpds["c"].loc[2][0] == 0.2 - - -class TestPredictMaximumLikelihoodEstimator: - """Test behaviour of predict using MLE""" - - def test_predictions_are_based_on_probabilities( - self, bn, train_data_discrete, test_data_c_discrete - ): - """Predictions made using the model should be based on the probabilities that are in the model""" - - bn.fit_cpds(train_data_discrete) - predictions = bn.predict(test_data_c_discrete, "c") - assert np.all( - predictions.values.reshape(len(predictions.values)) - == test_data_c_discrete["c"].values - ) - - def test_prediction_node_suffixed_as_prediction( - self, bn, train_data_discrete, test_data_c_discrete - ): - """The column that contains the values of the predicted node should be named node_prediction""" - - bn.fit_cpds(train_data_discrete) - predictions = bn.predict(test_data_c_discrete, "c") - assert "c_prediction" in predictions.columns - - def test_only_predicted_column_returned( - self, bn, train_data_discrete, test_data_c_discrete - ): - """The returned df should not contain any of the input data columns""" - - bn.fit_cpds(train_data_discrete) - predictions = bn.predict(test_data_c_discrete, "c") - assert len(predictions.columns) == 1 - - def test_predictions_are_not_appended_to_input_df( - self, bn, train_data_discrete, test_data_c_discrete - ): - """The 
predictions should not be appended to the input df""" - - expected_cols = test_data_c_discrete.columns - bn.fit_cpds(train_data_discrete) - bn.predict(test_data_c_discrete, "c") - assert np.array_equal(test_data_c_discrete.columns, expected_cols) - - def test_missing_parent(self, bn, train_data_discrete, test_data_c_discrete): - """Predictions made when parents are missing should still be reasonably accurate""" - - bn.fit_cpds(train_data_discrete) - predictions = bn.predict(test_data_c_discrete[["a", "b", "c", "d"]], "c") - - n = len(test_data_c_discrete) - - accuracy = ( - 1 - - np.count_nonzero( - predictions.values.reshape(len(predictions.values)) - - test_data_c_discrete["c"].values - ) - / n - ) - - assert accuracy > 0.9 - - def test_missing_non_parent(self, bn, train_data_discrete, test_data_c_discrete): - """It should be possible to make predictions with non-parent nodes missing""" - - bn.fit_cpds(train_data_discrete) - predictions = bn.predict(test_data_c_discrete[["b", "c", "d", "e"]], "c") - assert np.all( - predictions.values.reshape(len(predictions.values)) - == test_data_c_discrete["c"].values - ) - - -class TestPredictBayesianEstimator: - """Test behaviour of predict using BE""" - - def test_predictions_are_based_on_probabilities_dbeu( - self, bn, train_data_discrete, test_data_c_discrete - ): - """Predictions made using the model should be based on the probabilities that are in the model""" - - bn.fit_cpds( - train_data_discrete, - method="BayesianEstimator", - bayes_prior="BDeu", - equivalent_sample_size=5, - ) - predictions = bn.predict(test_data_c_discrete, "c") - assert np.all( - predictions.values.reshape(len(predictions.values)) - == test_data_c_discrete["c"].values - ) - - def test_predictions_are_based_on_probabilities_k2( - self, bn, train_data_discrete, test_data_c_discrete - ): - """Predictions made using the model should be based on the probabilities that are in the model""" - - bn.fit_cpds( - train_data_discrete, - 
method="BayesianEstimator", - bayes_prior="K2", - equivalent_sample_size=5, - ) - predictions = bn.predict(test_data_c_discrete, "c") - assert np.all( - predictions.values.reshape(len(predictions.values)) - == test_data_c_discrete["c"].values - ) - - -class TestPredictProbabilityMaximumLikelihoodEstimator: - """Test behaviour of predict_probability using MLE""" - - def test_expected_probabilities_are_predicted( - self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood - ): - """Probabilities should return exactly correct on a hand computable scenario""" - - bn.fit_cpds(train_data_discrete) - probability = bn.predict_probability(test_data_c_discrete, "c") - - assert all( - np.isclose( - probability.values.flatten(), test_data_c_likelihood.values.flatten() - ) - ) - - def test_missing_parent( - self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood - ): - """Probabilities made when parents are missing should still be reasonably accurate""" - - bn.fit_cpds(train_data_discrete) - probability = bn.predict_probability( - test_data_c_discrete[["a", "b", "c", "d"]], "c" - ) - - n = len(probability.values.flatten()) - - accuracy = ( - np.count_nonzero( - [ - 1 if math.isclose(a, b, abs_tol=0.15) else 0 - for a, b in zip( - probability.values.flatten(), - test_data_c_likelihood.values.flatten(), - ) - ] - ) - / n - ) - - assert accuracy > 0.8 - - def test_missing_non_parent( - self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood - ): - """It should be possible to make predictions with non-parent nodes missing""" - - bn.fit_cpds(train_data_discrete) - probability = bn.predict_probability( - test_data_c_discrete[["b", "c", "d", "e"]], "c" - ) - assert all( - np.isclose( - probability.values.flatten(), test_data_c_likelihood.values.flatten() - ) - ) - - -class TestPredictProbabilityBayesianEstimator: - """Test behaviour of predict_probability using BayesianEstimator""" - - def 
test_expected_probabilities_are_predicted( - self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood - ): - """Probabilities should return exactly correct on a hand computable scenario""" - - bn.fit_cpds( - train_data_discrete, - method="BayesianEstimator", - bayes_prior="BDeu", - equivalent_sample_size=1, - ) - probability = bn.predict_probability(test_data_c_discrete, "c") - assert all( - np.isclose( - probability.values.flatten(), - test_data_c_likelihood.values.flatten(), - atol=0.1, - ) - ) - - -class TestFitNodesStatesAndCPDs: - """Test behaviour of helper function""" - - def test_behaves_same_as_separate_calls(self, train_data_idx, train_data_discrete): - bn1 = DynamicBayesianNetwork(from_pandas(train_data_idx, w_threshold=0.3)) - bn2 = DynamicBayesianNetwork(from_pandas(train_data_idx, w_threshold=0.3)) - - bn1.fit_node_states(train_data_discrete).fit_cpds(train_data_discrete) - bn2.fit_node_states_and_cpds(train_data_discrete) - - assert bn1.edges == bn2.edges - assert bn1.node_states == bn2.node_states - - cpds1 = bn1.cpds - cpds2 = bn2.cpds - - assert cpds1.keys() == cpds2.keys() - - for k, df in cpds1.items(): - assert df.equals(cpds2[k]) - - -class TestLatentVariable: - @staticmethod - def mean_absolute_error(cpds_a, cpds_b): - """Compute the absolute error among each single parameter and average them out""" - - mae = 0 - n_param = 0 - - for node in cpds_a.keys(): - err = np.abs(cpds_a[node] - cpds_b[node]).values - mae += np.sum(err) - n_param += err.shape[0] * err.shape[1] - - return mae / n_param - - def test_em_algorithm(self): # pylint: disable=too-many-locals - """ - Test if `DynamicBayesianNetwork` works with EM algorithm. - We use a naive bayes + parents + an extra node not related to the latent variable. 
- """ - - # p0 p1 p2 - # \ | / - # z - # / | \ - # c0 c1 c2 - # | - # cc0 - np.random.seed(22) - - data, sm, _, true_lv_values = naive_bayes_plus_parents( - percentage_not_missing=0.1, - samples=1000, - p_z=0.7, - p_c=0.7, - ) - data["cc_0"] = np.where( - np.random.random(len(data)) < 0.5, data["c_0"], (data["c_0"] + 1) % 3 - ) - data.drop(columns=["z"], inplace=True) - - complete_data = data.copy(deep=True) - complete_data["z"] = true_lv_values - - # Baseline model: the structure of the figure trained with complete data. We try to reproduce it - complete_bn = DynamicBayesianNetwork( - StructureModel(list(sm.edges) + [("c_0", "cc_0")]) - ) - complete_bn.fit_node_states_and_cpds(complete_data) - - # BN without latent variable: All `p`s are connected to all `c`s + `c0` ->`cc0` - sm_no_lv = StructureModel( - [(f"p_{p}", f"c_{c}") for p in range(3) for c in range(3)] - + [("c_0", "cc_0")] - ) - bn = DynamicBayesianNetwork(sm_no_lv) - bn.fit_node_states(data) - bn.fit_cpds(data) - - # TEST 1: cc_0 does not depend on the latent variable so: - assert np.all(bn.cpds["cc_0"] == complete_bn.cpds["cc_0"]) - - # BN with latent variable - # When we add the latent variable, we add the edges in the image above - # and remove the connection among `p`s and `c`s - edges_to_add = list(sm.edges) - edges_to_remove = [(f"p_{p}", f"c_{c}") for p in range(3) for c in range(3)] - bn.add_node("z", edges_to_add, edges_to_remove) - bn.fit_latent_cpds("z", [0, 1, 2], data, stopping_delta=0.001) - - # TEST 2: cc_0 CPD should remain untouched by the EM algorithm - assert np.all(bn.cpds["cc_0"] == complete_bn.cpds["cc_0"]) - - # TEST 3: We should recover the correct CPDs quite accurately - assert bn.cpds.keys() == complete_bn.cpds.keys() - assert self.mean_absolute_error(bn.cpds, complete_bn.cpds) < 0.01 - - # TEST 4: Inference over recovered CPDs should be also accurate - eng = InferenceEngine(bn) - query = eng.query() - n_rows = complete_data.shape[0] - - for node in query: - assert ( - 
np.abs(query[node][0] - sum(complete_data[node] == 0) / n_rows) < 1e-2 - ) - assert ( - np.abs(query[node][1] - sum(complete_data[node] == 1) / n_rows) < 1e-2 - ) - - # TEST 5: Inference using predict and predict_probability functions - report = classification_report(bn, complete_data, "z") - _, auc = roc_auc(bn, complete_data, "z") - complete_report = classification_report(complete_bn, complete_data, "z") - _, complete_auc = roc_auc(complete_bn, complete_data, "z") - - for category, metrics in report.items(): - if isinstance(metrics, dict): - for key, val in metrics.items(): - assert np.abs(val - complete_report[category][key]) < 1e-2 - else: - assert np.abs(metrics - complete_report[category]) < 1e-2 - - assert np.abs(auc - complete_auc) < 1e-2 - - -class TestAddNode: - def test_add_node_not_in_edges_to_add(self): - """An error should be raised if the latent variable is NOT part of the edges to add""" - - with pytest.raises( - ValueError, - match="Should only add edges containing node 'd'", - ): - _, sm, _, _ = naive_bayes_plus_parents() - sm = StructureModel(list(sm.edges)) - bn = DynamicBayesianNetwork(sm) - bn.add_node("d", [("a", "z"), ("b", "z")], []) - - def test_add_node_in_edges_to_remove(self): - """An error should be raised if the latent variable is part of the edges to remove""" - - with pytest.raises( - ValueError, - match="Should only remove edges NOT containing node 'd'", - ): - _, sm, _, _ = naive_bayes_plus_parents() - sm = StructureModel(list(sm.edges)) - bn = DynamicBayesianNetwork(sm) - bn.add_node("d", [], [("a", "d"), ("b", "d")]) - - -class TestFitLatentCPDs: - @pytest.mark.parametrize("lv_name", [None, [], set(), {}, tuple(), 123, {}]) - def test_fit_invalid_lv_name(self, lv_name): - """An error should be raised if the latent variable is of an invalid type""" - - with pytest.raises( - ValueError, - match=r"Invalid latent variable name *", - ): - df, sm, _, _ = naive_bayes_plus_parents() - sm = StructureModel(list(sm.edges)) - bn = 
DynamicBayesianNetwork(sm) - bn.fit_latent_cpds(lv_name, [0, 1, 2], df) - - def test_fit_lv_not_added(self): - """An error should be raised if the latent variable is not added to the network yet""" - - with pytest.raises( - ValueError, - match=r"Latent variable 'd' not added to the network", - ): - df, sm, _, _ = naive_bayes_plus_parents() - sm = StructureModel(list(sm.edges)) - bn = DynamicBayesianNetwork(sm) - bn.fit_latent_cpds("d", [0, 1, 2], df) - - @pytest.mark.parametrize("lv_states", [None, [], set(), {}]) - def test_fit_invalid_lv_states(self, lv_states): - """An error should be raised if the latent variable has invalid states""" - - with pytest.raises( - ValueError, - match="Latent variable 'd' contains no states", - ): - df, sm, _, _ = naive_bayes_plus_parents() - sm = StructureModel(list(sm.edges)) - bn = DynamicBayesianNetwork(sm) - bn.add_node("d", [("z", "d")], []) - bn.fit_latent_cpds("d", lv_states, df) - - -class TestSetCPD: - """Test behaviour of adding a self-defined cpd""" - - def test_set_cpd(self, bn, good_cpd): - """The CPD of the target node should be the same as the self-defined table after adding""" - - bn.set_cpd("b", good_cpd) - assert bn.cpds["b"].values.tolist() == good_cpd.values.tolist() - - def test_set_other_cpd(self, bn, good_cpd): - """The CPD of nodes other than the target node should not be affected""" - - cpd = bn.cpds["a"].values.tolist() - bn.set_cpd("b", good_cpd) - cpd_after_adding = bn.cpds["a"].values.tolist() - - assert all( - val == val_after_adding - for val, val_after_adding in zip(*(cpd, cpd_after_adding)) - ) - - def test_set_cpd_to_non_existent_node(self, bn, good_cpd): - """Should raise error if adding a cpd to a non-existing node in Bayesian Network""" - - with pytest.raises( - ValueError, - match=r'Non-existing node "test"', - ): - bn.set_cpd("test", good_cpd) - - def test_set_bad_cpd(self, bn, bad_cpd): - """Should raise error if it the prpbability values do not sum up to 1 in the table""" - - with 
pytest.raises( - ValueError, - match=r"Sum or integral of conditional probabilites for node b is not equal to 1.", - ): - bn.set_cpd("b", bad_cpd) - - def test_no_overwritten_after_setting_bad_cpd(self, bn, bad_cpd): - """The cpd of bn won't be overwritten if adding a bad cpd""" - - original_cpd = bn.cpds["b"].values.tolist() - - try: - bn.set_cpd("b", bad_cpd) - except ValueError: - assert bn.cpds["b"].values.tolist() == original_cpd - - def test_bad_node_index(self, bn, good_cpd): - """Should raise an error when setting bad node index""" - - bad_cpd = good_cpd - bad_cpd.index.name = "test" - - with pytest.raises( - IndexError, - match=r"Wrong index values. Please check your indices", - ): - bn.set_cpd("b", bad_cpd) - - def test_bad_node_states_index(self, bn, good_cpd): - """Should raise an error when setting bad node states index""" - - bad_cpd = good_cpd.reindex([1, 2, 3]) - - with pytest.raises( - IndexError, - match=r"Wrong index values. Please check your indices", - ): - bn.set_cpd("b", bad_cpd) - - def test_bad_parent_node_index(self, bn, good_cpd): - """Should raise an error when setting bad parent node index""" - - bad_cpd = good_cpd - bad_cpd.columns = bad_cpd.columns.rename("test", level=1) - - with pytest.raises( - IndexError, - match=r"Wrong index values. Please check your indices", - ): - bn.set_cpd("b", bad_cpd) - - def test_bad_parent_node_states_index(self, bn, good_cpd): - """Should raise an error when setting bad parent node states index""" - - bad_cpd = good_cpd - bad_cpd.columns.set_levels(["test1", "test2"], level=0, inplace=True) - - with pytest.raises( - IndexError, - match=r"Wrong index values. 
Please check your indices", - ): - bn.set_cpd("b", bad_cpd) - - -class TestCPDsProperty: - """Test behaviour of the CPDs property""" - - def test_row_index_of_state_values(self, bn): - """CPDs should have row index set to values of all possible states of the node""" - - assert bn.cpds["a"].index.tolist() == sorted(list(bn.node_states["a"])) - - def test_col_index_of_parent_state_combinations(self, bn): - """CPDs should have a column multi-index of parent state permutations""" - - assert bn.cpds["a"].columns.names == ["b", "d"] - - -class TestInit: - """Test behaviour when constructing a DynamicBayesianNetwork""" - - def test_cycles_in_structure(self): - """An error should be raised if cycles are present""" - - with pytest.raises( - ValueError, - match=r"The given structure is not acyclic\. " - r"Please review the following cycle\.*", - ): - DynamicBayesianNetwork(StructureModel([(0, 1), (1, 2), (2, 0)])) - - @pytest.mark.parametrize( - "test_input,n_components", - [([(0, 1), (1, 2), (3, 4), (4, 6)], 2), ([(0, 1), (1, 2), (3, 4), (5, 6)], 3)], - ) - def test_disconnected_components(self, test_input, n_components): - """An error should be raised if there is more than one graph component""" - - with pytest.raises( - ValueError, - match=r"The given structure has " - + str(n_components) - + r" separated graph components\. 
" - r"Please make sure it has only one\.", - ): - DynamicBayesianNetwork(StructureModel(test_input)) - - -class TestStructure: - """Test behaviour of the property structure""" - - def test_get_structure(self): - """The structure retrieved should be the same""" - - sm = StructureModel() - sm.add_weighted_edges_from([(1, 2, 2.0)], origin="unknown") - sm.add_weighted_edges_from([(1, 3, 1.0)], origin="learned") - sm.add_weighted_edges_from([(3, 5, 0.7)], origin="expert") - - bn = DynamicBayesianNetwork(sm) - - sm_from_bn = bn.structure - - assert set(sm.edges.data("origin")) == set(sm_from_bn.edges.data("origin")) - assert set(sm.edges.data("weight")) == set(sm_from_bn.edges.data("weight")) - assert set(sm.nodes) == set(sm_from_bn.nodes) - - def test_set_structure(self): - """An error should be raised if setting the structure""" - - sm = StructureModel() - sm.add_weighted_edges_from([(1, 2, 2.0)], origin="unknown") - sm.add_weighted_edges_from([(1, 3, 1.0)], origin="learned") - sm.add_weighted_edges_from([(3, 5, 0.7)], origin="expert") - - bn = DynamicBayesianNetwork(sm) - - new_sm = StructureModel() - sm.add_weighted_edges_from([(2, 5, 3.0)], origin="unknown") - sm.add_weighted_edges_from([(2, 3, 2.0)], origin="learned") - sm.add_weighted_edges_from([(3, 4, 1.7)], origin="expert") - - with pytest.raises(AttributeError, match=r"can't set attribute"): - bn.structure = new_sm +#from causalnex.structure.dynotears import from_numpy_dynamic, from_pandas_dynamic +#from causalnex.network import DynamicBayesianNetwork + +''' +functions to test in DBN are fit_node_states, fit_node_states_and_cpds, fit_latent_cpds, predict, predict_probability +only change to dynotears is using DSM instead of SM +in DSM, functionality is still provided by nx.DiGraph +main update will be to update DBN to use different model than pgmpy.models.BayesianModel +just run regression tests on test_dynotears and test_dynamicstructure_model +''' \ No newline at end of file From 
15e23cb387edc73ac3712859038f1cc6d6c2fb71 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Thu, 29 Sep 2022 19:41:30 -0500 Subject: [PATCH 03/13] push tests for dynamic structure model --- tests/structure/test_dynamicstructuremodel.py | 818 ++++++++++++++++++ 1 file changed, 818 insertions(+) create mode 100644 tests/structure/test_dynamicstructuremodel.py diff --git a/tests/structure/test_dynamicstructuremodel.py b/tests/structure/test_dynamicstructuremodel.py new file mode 100644 index 0000000..b7f7741 --- /dev/null +++ b/tests/structure/test_dynamicstructuremodel.py @@ -0,0 +1,818 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. 
+# +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from networkx.exception import NodeNotFound + +from causalnex.structure import DynamicStructureModel, DynamicStructureNode +import re + +class TestDynamicStructureModel: + def test_init_has_origin(self): + """Creating a DynamicStructureModel using constructor should give all edges unknown origin""" + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)] + sm = DynamicStructureModel([(nodes[0], nodes[1])]) + assert (nodes[0].get_node_name(), nodes[1].get_node_name()) in sm.edges + assert (nodes[0].get_node_name(), nodes[1].get_node_name(), "unknown") in sm.edges.data("origin") + + def test_init_with_origin(self): + """should be possible to specify origin during init""" + + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)] + sm = DynamicStructureModel([(nodes[0], nodes[1])], origin="learned") + assert (nodes[0].get_node_name(), nodes[1].get_node_name(), "learned") in sm.edges.data("origin") + + def test_edge_unknown_property(self): + """should return only edges whose origin is unknown""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2, origin="unknown") + sm.add_edge(1, 3, origin="learned") + sm.add_edge(1, 4, origin="expert") + + assert sm.edges_with_origin("unknown") == [(1, 2)] + + def test_edge_learned_property(self): + """should return only edges whose origin is unknown""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2, origin="unknown") + sm.add_edge(1, 3, origin="learned") + sm.add_edge(1, 4, origin="expert") + + assert sm.edges_with_origin("learned") == [(1, 3)] + + def test_edge_expert_property(self): + """should return only edges whose origin is unknown""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2, origin="unknown") + sm.add_edge(1, 3, origin="learned") + sm.add_edge(1, 4, origin="expert") + + assert sm.edges_with_origin("expert") == [(1, 4)] + + def test_to_directed(self): + """should 
create a structure model""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)] + + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[0]), (nodes[1], nodes[2]), (nodes[2], nodes[3])] + sm.add_edges_from(edges) + + dag = sm.to_directed() + assert isinstance(dag, DynamicStructureModel) + assert all((edge[0].get_node_name(), edge[1].get_node_name()) in dag.edges for edge in edges) + + def test_to_undirected(self): + """should create an undirected Graph""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)] + + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[0]), (nodes[1], nodes[2]), (nodes[2], nodes[3])] + sm.add_edges_from(edges) + + udg = sm.to_undirected() + print(f'udg edges {udg.edges}') + assert all((edge[0].get_node_name(), edge[1].get_node_name()) in udg.edges for edge in [(nodes[1], nodes[2]), (nodes[2], nodes[3])]) + assert (nodes[0].get_node_name(), nodes[1].get_node_name()) in udg.edges or (nodes[1].get_node_name(), nodes[0].get_node_name()) in udg.edges + assert len(udg.edges) == 3 + + +class TestDynamicStructureModelAddEdge: + def test_add_edge_default(self): + """edges added with default origin should be identified as unknown origin""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2) + + assert (1, 2) in sm.edges + assert (1, 2, "unknown") in sm.edges.data("origin") + + def test_add_edge_unknown(self): + """edges added with unknown origin should be labelled as unknown origin""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2, "unknown") + + assert (1, 2) in sm.edges + assert (1, 2, "unknown") in sm.edges.data("origin") + + def test_add_edge_learned(self): + """edges added with learned origin should be labelled as learned origin""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2, "learned") + + assert (1, 2) in sm.edges + assert (1, 2, 
"learned") in sm.edges.data("origin") + + def test_add_edge_expert(self): + """edges added with expert origin should be labelled as expert origin""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2, "expert") + + assert (1, 2) in sm.edges + assert (1, 2, "expert") in sm.edges.data("origin") + + def test_add_edge_other(self): + """edges added with other origin should throw an error""" + + sm = DynamicStructureModel() + + with pytest.raises(ValueError, match="^Unknown origin: must be one of.*$"): + sm.add_edge(1, 2, "other") + + def test_add_edge_custom_attr(self): + """it should be possible to add an edge with custom attributes""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2, x="Y") + + assert (1, 2) in sm.edges + assert (1, 2, "Y") in sm.edges.data("x") + + def test_add_edge_multiple_times(self): + """adding an edge again should update the edges origin attr""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2, origin="unknown") + assert (1, 2, "unknown") in sm.edges.data("origin") + sm.add_edge(1, 2, origin="learned") + assert (1, 2, "learned") in sm.edges.data("origin") + + def test_add_multiple_edges(self): + """it should be possible to add multiple edges with different origins""" + + sm = DynamicStructureModel() + sm.add_edge(1, 2, origin="unknown") + sm.add_edge(1, 3, origin="learned") + sm.add_edge(1, 4, origin="expert") + + assert (1, 2, "unknown") in sm.edges.data("origin") + assert (1, 3, "learned") in sm.edges.data("origin") + assert (1, 4, "expert") in sm.edges.data("origin") + + +class TestDynamicStructureModelAddEdgesFrom: + def test_add_edges_from_default(self): + """edges added with default origin should be identified as unknown origin""" + print('******************* hello **************************') + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] + sm.add_edges_from(edges) + assert 
all((edge[0].get_node_name(), edge[1].get_node_name()) in sm.edges for edge in edges) + assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v in edges) + + def test_add_edges_from_unknown(self): + """edges added with unknown origin should be labelled as unknown origin""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] + sm.add_edges_from(edges, "unknown") + + assert all((u.get_node_name(), v.get_node_name()) in sm.edges for u, v in edges) + assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v in edges) + + def test_add_edges_from_learned(self): + """edges added with learned origin should be labelled as learned origin""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] + sm.add_edges_from(edges, "learned") + + assert all((u.get_node_name(), v.get_node_name()) in sm.edges for u, v in edges) + assert all((u.get_node_name(), v.get_node_name(), "learned") in sm.edges.data("origin") for u, v in edges) + + def test_add_edges_from_expert(self): + """edges added with expert origin should be labelled as expert origin""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] + sm.add_edges_from(edges, "expert") + + assert all((u.get_node_name(), v.get_node_name()) in sm.edges for u, v in edges) + assert all((u.get_node_name(), v.get_node_name(), "expert") in sm.edges.data("origin") for u, v in edges) + + def test_add_edges_from_other(self): + """edges added with other origin should throw an error""" + + sm = DynamicStructureModel() + + with pytest.raises(ValueError, match="^Unknown origin: must 
be one of.*$"): + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)] + sm.add_edges_from([(nodes[0], nodes[1])], "other") + + def test_add_edges_from_custom_attr(self): + """it should be possible to add edges with custom attributes""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] + sm.add_edges_from(edges, x="Y") + + assert all((u.get_node_name(), v.get_node_name()) in sm.edges for u, v in edges) + assert all((u.get_node_name(), v.get_node_name(), "Y") in sm.edges.data("x") for u, v in edges) + + def test_add_edges_from_multiple_times(self): + """adding edges again should update the edges origin attr""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] + sm.add_edges_from(edges, "unknown") + assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v in edges) + sm.add_edges_from(edges, "learned") + assert all((u.get_node_name(), v.get_node_name(), "learned") in sm.edges.data("origin") for u, v in edges) + + def test_add_multiple_edges(self): + """it should be possible to add multiple edges with different origins""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)] + sm.add_edges_from([(nodes[0], nodes[1])], origin="unknown") + sm.add_edges_from([(nodes[0], nodes[2])], origin="learned") + sm.add_edges_from([(nodes[0], nodes[3])], origin="expert") + + assert (nodes[0].get_node_name(), nodes[1].get_node_name(), "unknown") in sm.edges.data("origin") + assert (nodes[0].get_node_name(), nodes[2].get_node_name(), "learned") in sm.edges.data("origin") + assert (nodes[0].get_node_name(), nodes[3].get_node_name(), "expert") in sm.edges.data("origin") 
+ + +class TestDynamicStructureModelAddWeightedEdgesFrom: + def test_add_weighted_edges_from_default(self): + """edges added with default origin should be identified as unknown origin""" + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + sm.add_weighted_edges_from(edges) + + assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v, w in edges) + + def test_add_weighted_edges_from_unknown(self): + """edges added with unknown origin should be labelled as unknown origin""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + sm.add_weighted_edges_from(edges, origin="unknown") + + assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v, w in edges) + + def test_add_weighted_edges_from_learned(self): + """edges added with learned origin should be labelled as learned origin""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + sm.add_weighted_edges_from(edges, origin="learned") + + assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u.get_node_name(), v.get_node_name(), "learned") in sm.edges.data("origin") for u, v, w in edges) + + def test_add_weighted_edges_from_expert(self): + """edges added with expert origin should be labelled as expert origin""" + + sm = DynamicStructureModel() + nodes = 
[DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + #edges = [(1, 2, 0.5), (2, 3, 0.5)] + edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + sm.add_weighted_edges_from(edges, origin="expert") + + assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u.get_node_name(), v.get_node_name(), "expert") in sm.edges.data("origin") for u, v, w in edges) + + def test_add_weighted_edges_from_other(self): + """edges added with other origin should throw an error""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)] + + with pytest.raises(ValueError, match="^Unknown origin: must be one of.*$"): + sm.add_weighted_edges_from([(nodes[0], nodes[1], 0.5)], origin="other") + + def test_add_weighted_edges_from_custom_attr(self): + """it should be possible to add edges with custom attributes""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + #edges = [(1, 2, 0.5), (2, 3, 0.5)] + edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + sm.add_weighted_edges_from(edges, x="Y") + + assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u.get_node_name(), v.get_node_name(), "Y") in sm.edges.data("x") for u, v, _ in edges) + + def test_add_weighted_edges_from_multiple_times(self): + """adding edges again should update the edges origin attr""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + + #edges = [(1, 2, 0.5), (2, 3, 0.5)] + edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + + sm.add_weighted_edges_from(edges, origin="unknown") + assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v, _ in edges) + + sm.add_weighted_edges_from(edges, 
origin="learned") + assert all((u.get_node_name(), v.get_node_name(), "learned") in sm.edges.data("origin") for u, v, _ in edges) + + def test_add_multiple_weighted_edges(self): + """it should be possible to add multiple edges with different origins""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)] + sm.add_weighted_edges_from([(nodes[0], nodes[1], 0.5)], origin="unknown") + sm.add_weighted_edges_from([(nodes[0], nodes[2], 0.5)], origin="learned") + sm.add_weighted_edges_from([(nodes[0], nodes[3], 0.5)], origin="expert") + + assert ('1_lag0', '2_lag0', "unknown") in sm.edges.data("origin") + assert ('1_lag0', '3_lag0', "learned") in sm.edges.data("origin") + assert ('1_lag0', '4_lag0', "expert") in sm.edges.data("origin") + + +class TestDynamicStructureModelRemoveEdgesBelowThreshold: + def test_remove_edges_below_threshold(self): + """Edges whose weight is less than a defined threshold should be removed""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)] + + strong_edges = [(nodes[0], nodes[1], 1.0), (nodes[0], nodes[2], 0.8), (nodes[0], nodes[4], 2.0)] + weak_edges = [(nodes[0], nodes[3], 0.4), (nodes[1], nodes[2], 0.6), (nodes[2], nodes[4], 0.5)] + sm.add_weighted_edges_from(strong_edges) + sm.add_weighted_edges_from(weak_edges) + + sm.remove_edges_below_threshold(0.7) + assert set(sm.edges(data="weight")) == set((u.get_node_name(), v.get_node_name(), w) for u, v, w in strong_edges) + + def test_negative_weights(self): + """Negative edges whose absolute value is greater than the defined threshold should not be removed""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)] + + strong_edges = 
[(nodes[0], nodes[1], -3.0), (nodes[2], nodes[0], 0.7), (nodes[0], nodes[4], -2.0)] + weak_edges = [(nodes[0], nodes[3], 0.4), (nodes[1], nodes[2], -0.6), (nodes[2], nodes[4], -0.5)] + + sm.add_weighted_edges_from(strong_edges) + sm.add_weighted_edges_from(weak_edges) + + sm.remove_edges_below_threshold(0.7) + + assert set(sm.edges(data="weight")) == set((u.get_node_name(), v.get_node_name(), w) for u, v, w in strong_edges) + + def test_equal_weights(self): + """Edges whose absolute value is equal to the defined threshold should not be removed""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)] + + strong_edges = [(nodes[0], nodes[1], 1.0), (nodes[0], nodes[4], 2.0)] + equal_edges = [(nodes[0], nodes[2], 0.6), (nodes[1], nodes[2], 0.6)] + weak_edges = [(nodes[0], nodes[3], 0.4), (nodes[2], nodes[4], 0.5)] + sm.add_weighted_edges_from(strong_edges) + sm.add_weighted_edges_from(equal_edges) + sm.add_weighted_edges_from(weak_edges) + + sm.remove_edges_below_threshold(0.6) + + assert set(sm.edges(data="weight")) == set.union( + set((u.get_node_name(), v.get_node_name(), w) for u, v, w in strong_edges), + set((u.get_node_name(), v.get_node_name(), w) for u, v, w in equal_edges) + ) + + def test_graph_with_no_edges(self): + """Can still run even if the graph is without edges""" + sm = DynamicStructureModel() + # (var, lag) - all nodes here are in current timestep + nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)] + + sm.add_nodes(nodes) + sm.remove_edges_below_threshold(0.6) + + assert set(sm.nodes) == set([node.get_node_name() for node in nodes]) + assert set(sm.edges) == set() + + +class TestDynamicStructureModelGetLargestSubgraph: + @pytest.mark.parametrize( + "test_input, expected", + [ + ([(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), 
DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0))]), + ([(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0))]), + + #([(0, 1), (1, 2), (1, 3), (4, 6)], [(0, 1), (1, 2), (1, 3)]), + #([(3, 4), (3, 5), (7, 6)], [(3, 4), (3, 5)]), + ], + ) + def test_get_largest_subgraph(self, test_input, expected): + """Should be able to return the largest subgraph""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + largest_subgraph = sm.get_largest_subgraph() + + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected) + + assert set(largest_subgraph.nodes) == set(expected_graph.nodes) + assert set(largest_subgraph.edges) == set(expected_graph.edges) + + def test_more_than_one_largest(self): + """Return the first largest when there are more than one largest subgraph""" + + nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)] + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[3], nodes[4]), (nodes[3], nodes[5])] + sm = DynamicStructureModel() + sm.add_edges_from(edges) + largest_subgraph = sm.get_largest_subgraph() + + expected_edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected_edges) + + assert set(largest_subgraph.nodes) == set(expected_graph.nodes) + assert set(largest_subgraph.edges) == set(expected_graph.edges) + + def 
test_empty(self): + """Should return None if the structure model is empty""" + + sm = DynamicStructureModel() + assert sm.get_largest_subgraph() is None + + def test_isolates(self): + """Should return None if the structure model only contains isolates""" + + sm = DynamicStructureModel() + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), + DynamicStructureNode(2, 0), DynamicStructureNode(7, 0)] + sm.add_nodes(nodes) + assert sm.get_largest_subgraph() is None + + def test_isolates_nodes_and_edges(self): + """Should be able to return the largest subgraph""" + + nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0), DynamicStructureNode(7, 0), DynamicStructureNode(8, 0), DynamicStructureNode(9, 0)] + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[5], nodes[6])] + isolated_nodes = [nodes[7], nodes[8], nodes[9]] + sm = DynamicStructureModel() + sm.add_edges_from(edges) + sm.add_nodes(isolated_nodes) + largest_subgraph = sm.get_largest_subgraph() + + expected_edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3])] + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected_edges) + + assert set(largest_subgraph.nodes) == set(expected_graph.nodes) + assert set(largest_subgraph.edges) == set(expected_graph.edges) + + def test_different_origins_and_weights(self): + """The largest subgraph returned should still have the edge data preserved from the original graph""" + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0)] + sm = DynamicStructureModel() + sm.add_weighted_edges_from([(nodes[0], nodes[1], 2.0)], origin="unknown") + sm.add_weighted_edges_from([(nodes[0], nodes[2], 1.0)], origin="learned") + 
sm.add_weighted_edges_from([(nodes[3], nodes[4], 0.7)], origin="expert") + + largest_subgraph = sm.get_largest_subgraph() + + assert set(largest_subgraph.edges.data("origin")) == { + (nodes[0].get_node_name(), nodes[1].get_node_name(), "unknown"), + (nodes[0].get_node_name(), nodes[2].get_node_name(), "learned"), + } + assert set(largest_subgraph.edges.data("weight")) == {(nodes[0].get_node_name(), nodes[1].get_node_name(), 2.0), (nodes[0].get_node_name(), nodes[2].get_node_name(), 1.0)} + + +class TestDynamicStructureModelGetTargetSubgraph: + @pytest.mark.parametrize( + "target_node, test_input, expected", + [ + (DynamicStructureNode(1, 0), [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0))]), + (DynamicStructureNode(3, 0), [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0))]), + (DynamicStructureNode(7, 0), [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(1, 0))], [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))]), + ], + ) + def test_get_target_subgraph(self, target_node, test_input, expected): + """Should be able to return the subgraph with the specified node""" + + sm = 
DynamicStructureModel() + sm.add_edges_from(test_input) + subgraph = sm.get_target_subgraph(target_node) + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + @pytest.mark.parametrize( + "target_node, test_input, expected", + [ + ( + DynamicStructureNode('a', 0), + [(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), (DynamicStructureNode('c', 0), DynamicStructureNode('d', 0)), (DynamicStructureNode('e', 0), DynamicStructureNode('f', 0))], + [(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), (DynamicStructureNode('c', 0), DynamicStructureNode('d', 0))], + ), + ( + DynamicStructureNode('g', 0), + [(DynamicStructureNode('g', 0), DynamicStructureNode('h', 0)), (DynamicStructureNode('g', 0), DynamicStructureNode('z', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), (DynamicStructureNode('c', 0), DynamicStructureNode('d', 0))], + [(DynamicStructureNode('g', 0), DynamicStructureNode('h', 0)), (DynamicStructureNode('g', 0), DynamicStructureNode('z', 0))], + ), + ], + ) + def test_get_subgraph_string(self, target_node, test_input, expected): + """Should be able to return the subgraph with the specified node""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + subgraph = sm.get_target_subgraph(target_node) + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + @pytest.mark.parametrize( + "target_node, test_input", + [ + ( + DynamicStructureNode(7, 0), + [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 
0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))] + ), + ( + DynamicStructureNode(1, 0), + [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))] + ) + ], + ) + def test_node_not_in_graph(self, target_node, test_input): + """Should raise an error if the target_node is not found in the graph""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + + with pytest.raises( + NodeNotFound, + match=re.escape(f"Node {target_node} not found in the graph"), + ): + sm.get_target_subgraph(target_node) + + def test_isolates(self): + """Should return an isolated node""" + + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), + DynamicStructureNode(2, 0), DynamicStructureNode(7, 0)] + sm = DynamicStructureModel() + sm.add_nodes(nodes) + subgraph = sm.get_target_subgraph(DynamicStructureNode(1, 0)) + expected_graph = DynamicStructureModel() + expected_graph.add_node(DynamicStructureNode(1, 0)) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + def test_isolates_nodes_and_edges(self): + """Should be able to return the subgraph with the specified node""" + + nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0), DynamicStructureNode(7, 0), DynamicStructureNode(8, 0), DynamicStructureNode(9, 0)] + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[5], nodes[6]), (nodes[4], nodes[5])] + isolated_nodes = [nodes[7], nodes[8], nodes[9]] + sm = DynamicStructureModel() + sm.add_edges_from(edges) + sm.add_nodes(isolated_nodes) + subgraph = sm.get_target_subgraph(nodes[5]) + 
expected_edges = [(nodes[5], nodes[6]), (nodes[4], nodes[5])] + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected_edges) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + def test_different_origins_and_weights(self): + """The subgraph returned should still have the edge data preserved from the original graph""" + + sm = DynamicStructureModel() + sm.add_weighted_edges_from([(1, 2, 2.0)], origin="unknown") + sm.add_weighted_edges_from([(1, 3, 1.0)], origin="learned") + sm.add_weighted_edges_from([(5, 6, 0.7)], origin="expert") + + subgraph = sm.get_target_subgraph(2) + + assert set(subgraph.edges.data("origin")) == { + (1, 2, "unknown"), + (1, 3, "learned"), + } + assert set(subgraph.edges.data("weight")) == {(1, 2, 2.0), (1, 3, 1.0)} + + def test_instance_type(self): + """The subgraph returned should still be a DynamicStructureModel instance""" + + sm = DynamicStructureModel() + sm.add_edges_from([(0, 1), (1, 2), (1, 3), (4, 6)]) + subgraph = sm.get_target_subgraph(2) + + assert isinstance(subgraph, DynamicStructureModel) + + def test_get_target_subgraph_twice(self): + """get_target_subgraph should be able to run more than once""" + + sm = DynamicStructureModel() + sm.add_edges_from([(0, 1), (1, 2), (1, 3), (4, 6)]) + + subgraph = sm.get_target_subgraph(0) + subgraph.remove_edge(0, 1) + subgraph = subgraph.get_target_subgraph(1) + + expected_graph = DynamicStructureModel() + expected_edges = [(1, 2), (1, 3)] + expected_graph.add_edges_from(expected_edges) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + +class TestDynamicStructureModelGetMarkovBlanket: + @pytest.mark.parametrize( + "target_node, test_input, expected", + [ + (1, [(0, 1), (1, 2), (1, 3), (4, 5)], [(0, 1), (1, 2), (1, 3)]), + (1, [(0, 1), (1, 2), (1, 3), (4, 3)], [(0, 1), (1, 2), (1, 3), (4, 3)]), + (3, [(3, 4), (3, 5), (6, 7)], 
[(3, 4), (3, 5)]), + (7, [(7, 8), (1, 2), (6, 7), (2, 3), (5, 8)], [(7, 8), (6, 7), (5, 8)]), + ], + ) + def test_get_markov_blanket_single(self, target_node, test_input, expected): + """Should be able to return Markov blanket with the specified single node""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + blanket = sm.get_markov_blanket(target_node) + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected) + + assert set(blanket.nodes) == set(expected_graph.nodes) + assert set(blanket.edges) == set(expected_graph.edges) + + @pytest.mark.parametrize( + "target_nodes, test_input, expected", + [ + ( + [1, 4], + [(0, 1), (1, 2), (1, 3), (4, 5)], + [(0, 1), (1, 2), (1, 3), (4, 5)], + ), + ([2, 4], [(0, 1), (1, 2), (1, 3), (4, 3)], [(1, 2), (1, 3), (4, 3)]), + ([3, 6], [(3, 4), (3, 5), (6, 7)], [(3, 4), (3, 5), (6, 7)]), + ( + [2, 5], + [(7, 8), (1, 2), (6, 7), (2, 3), (5, 8)], + [(1, 2), (2, 3), (7, 8), (5, 8)], + ), + ], + ) + def test_get_markov_blanket_multiple(self, target_nodes, test_input, expected): + """Should be able to return Markov blanket with the specified list of nodes""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + blanket = sm.get_markov_blanket(target_nodes) + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected) + + assert set(blanket.nodes) == set(expected_graph.nodes) + assert set(blanket.edges) == set(expected_graph.edges) + + @pytest.mark.parametrize( + "target_node, test_input, expected", + [ + ( + "a", + [("a", "b"), ("a", "c"), ("c", "d"), ("e", "f")], + [("a", "b"), ("a", "c")], + ), + ( + "g", + [("g", "h"), ("g", "z"), ("a", "b"), ("a", "c"), ("c", "d")], + [("g", "h"), ("g", "z")], + ), + ], + ) + def test_get_markov_blanket_string(self, target_node, test_input, expected): + """Should be able to return the subgraph with the specified node""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + blanket = 
sm.get_markov_blanket(target_node) + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected) + + assert set(blanket.nodes) == set(expected_graph.nodes) + assert set(blanket.edges) == set(expected_graph.edges) + + @pytest.mark.parametrize( + "target_node, test_input", + [ + (7, [(0, 1), (1, 2), (1, 3), (4, 6)]), + (1, [(3, 4), (3, 5), (7, 6)]), + ([1, 7], [(0, 1), (1, 2), (1, 3), (4, 6)]), + ([8, 2], [(3, 4), (3, 5), (7, 6)]), + ], + ) + def test_node_not_in_graph(self, target_node, test_input): + """Should raise an error if the target_node is not found in the graph""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + + with pytest.raises( + NodeNotFound, + match=f"Node {target_node} not found in the graph", + ): + sm.get_markov_blanket(target_node) + + def test_isolates(self): + """Should return an isolated node""" + + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), + DynamicStructureNode(2, 0), DynamicStructureNode(7, 0)] + sm = DynamicStructureModel() + sm.add_nodes(nodes) + blanket = sm.get_markov_blanket(1) + + expected_graph = DynamicStructureModel() + expected_graph.add_node(1) + + assert set(blanket.nodes) == set(expected_graph.nodes) + assert set(blanket.edges) == set(expected_graph.edges) + + def test_isolates_nodes_and_edges(self): + """Should be able to return the subgraph with the specified node""" + + edges = [(0, 1), (1, 2), (1, 3), (5, 6), (4, 5)] + isolated_nodes = [DynamicStructureNode(7, 0), DynamicStructureNode(8, 0), DynamicStructureNode(9, 0)] + sm = DynamicStructureModel() + sm.add_edges_from(edges) + sm.add_nodes(isolated_nodes) + subgraph = sm.get_markov_blanket(5) + expected_edges = [(5, 6), (4, 5)] + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected_edges) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + def test_instance_type(self): + """The 
subgraph returned should still be a DynamicStructureModel instance""" + + sm = DynamicStructureModel() + sm.add_edges_from([(0, 1), (1, 2), (1, 3), (4, 6)]) + subgraph = sm.get_markov_blanket(DynamicStructureNode(2, 0)) + + assert isinstance(subgraph, DynamicStructureModel) From 0c3fd1924af9a240e979d23fb2446af77b9a5e45 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Mon, 3 Oct 2022 21:45:07 -0500 Subject: [PATCH 04/13] fix tests in dynamic structure model --- causalnex/structure/__init__.py | 5 +- causalnex/structure/structuremodel.py | 50 ++++----- tests/structure/test_dynamicstructuremodel.py | 104 ++++++++++-------- 3 files changed, 79 insertions(+), 80 deletions(-) diff --git a/causalnex/structure/__init__.py b/causalnex/structure/__init__.py index 4980400..190b745 100644 --- a/causalnex/structure/__init__.py +++ b/causalnex/structure/__init__.py @@ -35,7 +35,6 @@ "notears", "dynotears", "data_generators", - "node", "DAGRegressor", "DAGClassifier", "DynamicStructureModel", @@ -43,6 +42,4 @@ ] from .pytorch.sklearn import DAGClassifier, DAGRegressor -from .structuremodel import StructureModel -from .structuremodel import DynamicStructureModel -from .structuremodel import DynamicStructureNode \ No newline at end of file +from .structuremodel import StructureModel, DynamicStructureModel, DynamicStructureNode \ No newline at end of file diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index 2b9af2a..4ec6c6c 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -98,7 +98,6 @@ def __init__(self, incoming_graph_data=None, origin="unknown", **attr): super().__init__(incoming_graph_data, **attr) for u_of_edge, v_of_edge in self.edges: - print(f'in for loop in init {u_of_edge}, {v_of_edge}') self[u_of_edge][v_of_edge]["origin"] = origin def to_directed_class(self): @@ -340,23 +339,15 @@ class DynamicStructureNode(NamedTuple): node: Union[int, str] time_step: int - @classmethod - def 
__instancecheck__(cls, instance): - print(f'inside instance check function {type(instance)}') - if hasattr(instance, 'node') and hasattr(instance, 'time_step') and hasattr(instance, 'get_node_name'): - return True - else: - return False - def get_node_name(self): return f'{self.node}_lag{self.time_step}' - def __eq__(self, other): - if isinstance(other, DynamicStructureNode): - return self.get_node_name() == other.get_node_name() - return False def checkargs(function): + """ + This function ensures the arguments passed to the methods in ``DynamicStructureModel`` are of the correct type. + Specifically that they are of type ``DynamicStructureNode``. + """ def _f(*arguments, **attr): for index, argument in enumerate(inspect.getfullargspec(function)[0]): if argument == 'self': @@ -377,6 +368,9 @@ def _f(*arguments, **attr): elif isinstance(arguments[index], types.GeneratorType): # this comes from networkx, coerce into correct types pass + elif hasattr(function.__annotations__[argument], '__args__'): + if not isinstance(arguments[index], function.__annotations__[argument].__args__[0]): + raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) elif not isinstance(arguments[index], function.__annotations__[argument]): raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) except IndexError as e: @@ -463,7 +457,6 @@ def get_target_subgraph(self, node: DynamicStructureNode) -> "DynamicStructureMo """ node_name = node.get_node_name() if node_name in self.nodes: - print(f'node {node} in self nodes {self.nodes}') for component in nx.weakly_connected_components(self): subgraph = self.subgraph(component).copy() @@ -492,27 +485,28 @@ def get_markov_blanket( nodes = [nodes] blanket_nodes = set() - + for node in set(nodes): # Ensure target nodes are unique - if node not in set(self.nodes): - raise NodeNotFound(f"Node {node} not found in the graph.") + node_name = node.get_node_name() + if 
node_name not in set(self.nodes): + raise NodeNotFound(f"Node {node} not found in the graph") - blanket_nodes.add(node) - blanket_nodes.update(self.predecessors(node)) + blanket_nodes.add(node_name) + blanket_nodes.update(self.predecessors(node_name)) - for child in self.successors(node): + for child in self.successors(node_name): blanket_nodes.add(child) blanket_nodes.update(self.predecessors(child)) blanket = DynamicStructureModel() - blanket.add_nodes(blanket_nodes) - blanket.add_weighted_edges_from( - [ - (u, v, w) - for u, v, w in self.edges(data="weight") - if u in blanket_nodes and v in blanket_nodes - ] - ) + blanket_dyn_nodes = [DynamicStructureNode(node_name[0], node_name[-1]) for node_name in blanket_nodes] + blanket.add_nodes(blanket_dyn_nodes) + + blanket_weighted_edges = [] + for u, v, w in self.edges(data="weight"): + if u in blanket_nodes and v in blanket_nodes: + blanket_weighted_edges.append((DynamicStructureNode(u[0], u[-1]), DynamicStructureNode(v[0], v[-1]), w)) + blanket.add_weighted_edges_from(blanket_weighted_edges) return blanket # disabled: W0221: Parameters differ from overridden 'add_edge' method (arguments-differ) diff --git a/tests/structure/test_dynamicstructuremodel.py b/tests/structure/test_dynamicstructuremodel.py index b7f7741..1910617 100644 --- a/tests/structure/test_dynamicstructuremodel.py +++ b/tests/structure/test_dynamicstructuremodel.py @@ -100,7 +100,7 @@ def test_to_undirected(self): sm.add_edges_from(edges) udg = sm.to_undirected() - print(f'udg edges {udg.edges}') + assert all((edge[0].get_node_name(), edge[1].get_node_name()) in udg.edges for edge in [(nodes[1], nodes[2]), (nodes[2], nodes[3])]) assert (nodes[0].get_node_name(), nodes[1].get_node_name()) in udg.edges or (nodes[1].get_node_name(), nodes[0].get_node_name()) in udg.edges assert len(udg.edges) == 3 @@ -639,24 +639,27 @@ def test_different_origins_and_weights(self): """The subgraph returned should still have the edge data preserved from the original 
graph""" sm = DynamicStructureModel() - sm.add_weighted_edges_from([(1, 2, 2.0)], origin="unknown") - sm.add_weighted_edges_from([(1, 3, 1.0)], origin="learned") - sm.add_weighted_edges_from([(5, 6, 0.7)], origin="expert") + nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0)] + + sm.add_weighted_edges_from([(nodes[0], nodes[1], 2.0)], origin="unknown") + sm.add_weighted_edges_from([(nodes[0], nodes[2], 1.0)], origin="learned") + sm.add_weighted_edges_from([(nodes[3], nodes[4], 0.7)], origin="expert") - subgraph = sm.get_target_subgraph(2) + subgraph = sm.get_target_subgraph(nodes[1]) assert set(subgraph.edges.data("origin")) == { - (1, 2, "unknown"), - (1, 3, "learned"), + (nodes[0].get_node_name(), nodes[1].get_node_name(), "unknown"), + (nodes[0].get_node_name(), nodes[2].get_node_name(), "learned"), } - assert set(subgraph.edges.data("weight")) == {(1, 2, 2.0), (1, 3, 1.0)} + assert set(subgraph.edges.data("weight")) == {(nodes[0].get_node_name(), nodes[1].get_node_name(), 2.0), (nodes[0].get_node_name(), nodes[2].get_node_name(), 1.0)} def test_instance_type(self): """The subgraph returned should still be a DynamicStructureModel instance""" sm = DynamicStructureModel() - sm.add_edges_from([(0, 1), (1, 2), (1, 3), (4, 6)]) - subgraph = sm.get_target_subgraph(2) + nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)] + sm.add_edges_from([(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[4], nodes[5])]) + subgraph = sm.get_target_subgraph(nodes[2]) assert isinstance(subgraph, DynamicStructureModel) @@ -664,14 +667,15 @@ def test_get_target_subgraph_twice(self): """get_target_subgraph should be able to run more than once""" sm = DynamicStructureModel() - sm.add_edges_from([(0, 1), (1, 2), (1, 3), (4, 6)]) + nodes = 
[DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)] + sm.add_edges_from([(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[4], nodes[5])]) - subgraph = sm.get_target_subgraph(0) - subgraph.remove_edge(0, 1) - subgraph = subgraph.get_target_subgraph(1) + subgraph = sm.get_target_subgraph(nodes[0]) + subgraph.remove_edge(nodes[0].get_node_name(), nodes[1].get_node_name()) + subgraph = subgraph.get_target_subgraph(nodes[1]) expected_graph = DynamicStructureModel() - expected_edges = [(1, 2), (1, 3)] + expected_edges = [(nodes[1], nodes[2]), (nodes[1], nodes[3])] expected_graph.add_edges_from(expected_edges) assert set(subgraph.nodes) == set(expected_graph.nodes) @@ -681,11 +685,11 @@ def test_get_target_subgraph_twice(self): class TestDynamicStructureModelGetMarkovBlanket: @pytest.mark.parametrize( "target_node, test_input, expected", - [ - (1, [(0, 1), (1, 2), (1, 3), (4, 5)], [(0, 1), (1, 2), (1, 3)]), - (1, [(0, 1), (1, 2), (1, 3), (4, 3)], [(0, 1), (1, 2), (1, 3), (4, 3)]), - (3, [(3, 4), (3, 5), (6, 7)], [(3, 4), (3, 5)]), - (7, [(7, 8), (1, 2), (6, 7), (2, 3), (5, 8)], [(7, 8), (6, 7), (5, 8)]), + [ + (DynamicStructureNode(1, 0), [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(5, 0))], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0))]), + (DynamicStructureNode(1, 0), [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0))], [(DynamicStructureNode(0, 0), 
DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0))]), + (DynamicStructureNode(3, 0), [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0))], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0))]), + (DynamicStructureNode(7, 0), [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0))], [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0))]), ], ) def test_get_markov_blanket_single(self, target_node, test_input, expected): @@ -704,16 +708,16 @@ def test_get_markov_blanket_single(self, target_node, test_input, expected): "target_nodes, test_input, expected", [ ( - [1, 4], - [(0, 1), (1, 2), (1, 3), (4, 5)], - [(0, 1), (1, 2), (1, 3), (4, 5)], + [DynamicStructureNode(1, 0), DynamicStructureNode(4, 0)], + [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(5, 0))], + [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(5, 0))], ), - ([2, 4], [(0, 1), (1, 2), (1, 3), (4, 3)], [(1, 2), (1, 3), (4, 3)]), - ([3, 6], [(3, 4), (3, 5), (6, 7)], [(3, 4), (3, 5), 
(6, 7)]), + ([DynamicStructureNode(2, 0), DynamicStructureNode(4, 0)], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0))], [(DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0))]), + ([DynamicStructureNode(3, 0), DynamicStructureNode(6, 0)], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0))], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0))]), ( - [2, 5], - [(7, 8), (1, 2), (6, 7), (2, 3), (5, 8)], - [(1, 2), (2, 3), (7, 8), (5, 8)], + [DynamicStructureNode(2, 0), DynamicStructureNode(5, 0)], + [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0))], + [(DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0))], ), ], ) @@ -733,14 +737,14 @@ def test_get_markov_blanket_multiple(self, target_nodes, test_input, expected): "target_node, test_input, expected", [ ( - "a", - [("a", "b"), ("a", "c"), ("c", "d"), ("e", "f")], - [("a", "b"), ("a", "c")], + DynamicStructureNode('a', 0), + [(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), 
(DynamicStructureNode('c', 0), DynamicStructureNode('d', 0)), (DynamicStructureNode('e', 0), DynamicStructureNode('f', 0))], + [(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0))], ), ( - "g", - [("g", "h"), ("g", "z"), ("a", "b"), ("a", "c"), ("c", "d")], - [("g", "h"), ("g", "z")], + DynamicStructureNode('g', 0), + [(DynamicStructureNode('g', 0), DynamicStructureNode('h', 0)), (DynamicStructureNode('g', 0), DynamicStructureNode('z', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), (DynamicStructureNode('c', 0), DynamicStructureNode('d', 0))], + [(DynamicStructureNode('g', 0), DynamicStructureNode('h', 0)), (DynamicStructureNode('g', 0), DynamicStructureNode('z', 0))], ), ], ) @@ -759,10 +763,10 @@ def test_get_markov_blanket_string(self, target_node, test_input, expected): @pytest.mark.parametrize( "target_node, test_input", [ - (7, [(0, 1), (1, 2), (1, 3), (4, 6)]), - (1, [(3, 4), (3, 5), (7, 6)]), - ([1, 7], [(0, 1), (1, 2), (1, 3), (4, 6)]), - ([8, 2], [(3, 4), (3, 5), (7, 6)]), + (DynamicStructureNode(7, 0), [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))]), + (DynamicStructureNode(1, 0), [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))]), + #([DynamicStructureNode(1, 0), DynamicStructureNode(7, 0)], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))]), + #([DynamicStructureNode(8, 0), DynamicStructureNode(2, 0)], 
[(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))]), ], ) def test_node_not_in_graph(self, target_node, test_input): @@ -773,7 +777,7 @@ def test_node_not_in_graph(self, target_node, test_input): with pytest.raises( NodeNotFound, - match=f"Node {target_node} not found in the graph", + match=re.escape(f"Node {target_node} not found in the graph"), ): sm.get_markov_blanket(target_node) @@ -784,10 +788,10 @@ def test_isolates(self): DynamicStructureNode(2, 0), DynamicStructureNode(7, 0)] sm = DynamicStructureModel() sm.add_nodes(nodes) - blanket = sm.get_markov_blanket(1) + blanket = sm.get_markov_blanket(nodes[0]) expected_graph = DynamicStructureModel() - expected_graph.add_node(1) + expected_graph.add_node(nodes[0]) assert set(blanket.nodes) == set(expected_graph.nodes) assert set(blanket.edges) == set(expected_graph.edges) @@ -795,13 +799,16 @@ def test_isolates(self): def test_isolates_nodes_and_edges(self): """Should be able to return the subgraph with the specified node""" - edges = [(0, 1), (1, 2), (1, 3), (5, 6), (4, 5)] - isolated_nodes = [DynamicStructureNode(7, 0), DynamicStructureNode(8, 0), DynamicStructureNode(9, 0)] + nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0), + DynamicStructureNode(7, 0), DynamicStructureNode(8, 0), DynamicStructureNode(9, 0)] + edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[5], nodes[6]), (nodes[4], nodes[5])] + isolated_nodes = [nodes[7], nodes[8], nodes[9]] sm = DynamicStructureModel() sm.add_edges_from(edges) sm.add_nodes(isolated_nodes) - subgraph = sm.get_markov_blanket(5) - expected_edges = [(5, 6), (4, 5)] + subgraph = sm.get_markov_blanket(nodes[5]) + expected_edges = [(nodes[5], nodes[6]), (nodes[4], 
nodes[5])] expected_graph = DynamicStructureModel() expected_graph.add_edges_from(expected_edges) @@ -810,9 +817,10 @@ def test_isolates_nodes_and_edges(self): def test_instance_type(self): """The subgraph returned should still be a DynamicStructureModel instance""" - + nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)] sm = DynamicStructureModel() - sm.add_edges_from([(0, 1), (1, 2), (1, 3), (4, 6)]) - subgraph = sm.get_markov_blanket(DynamicStructureNode(2, 0)) + sm.add_edges_from([(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[4], nodes[5])]) + subgraph = sm.get_markov_blanket(nodes[2]) assert isinstance(subgraph, DynamicStructureModel) From 71b2f93573e4cd31ac0af4a960bb799775feb883 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Thu, 6 Oct 2022 19:22:43 -0500 Subject: [PATCH 05/13] dynotears tests passing --- causalnex/structure/dynotears.py | 4 ++-- causalnex/structure/structuremodel.py | 9 +++++++++ tests/structure/test_dynotears.py | 2 ++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/causalnex/structure/dynotears.py b/causalnex/structure/dynotears.py index 65231ba..7e1c309 100644 --- a/causalnex/structure/dynotears.py +++ b/causalnex/structure/dynotears.py @@ -130,8 +130,8 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments sm.add_weighted_edges_from( [ ( - _format_name_from_pandas(idx_col, u), - _format_name_from_pandas(idx_col, v), + DynamicStructureNode(idx_col[int(u[0])], u[-1]), # _format_name_from_pandas(idx_col, u), idx_col[int(u[0])] + DynamicStructureNode(idx_col[int(v[0])], v[-1]), w, ) for u, v, w in g.edges.data("weight") diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index 4ec6c6c..6f76518 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -32,6 +32,7 @@ """ from typing import 
Any, Hashable, List, Set, Tuple, Union, NamedTuple +from collections import Iterable import networkx as nx import numpy as np @@ -545,6 +546,10 @@ def add_edges_from( """ _validate_origin(origin) + if isinstance(ebunch_to_add, Iterable) and not ebunch_to_add: + super().add_edges_from(ebunch_to_add, **attr) + return + if isinstance(ebunch_to_add, types.GeneratorType): dsn_ebunch = [] for e in ebunch_to_add: @@ -599,6 +604,10 @@ def add_weighted_edges_from( """ _validate_origin(origin) + if isinstance(ebunch_to_add, Iterable) and not ebunch_to_add: + super().add_weighted_edges_from(ebunch_to_add, weight=weight, **attr) + return + if isinstance(ebunch_to_add, types.GeneratorType): dsn_ebunch = [(DynamicStructureNode(e[0][0], e[0][-1]).get_node_name(), DynamicStructureNode(e[1][0], e[1][-1]).get_node_name(), e[2]) for e in ebunch_to_add] else: diff --git a/tests/structure/test_dynotears.py b/tests/structure/test_dynotears.py index 31094e6..9ce5771 100644 --- a/tests/structure/test_dynotears.py +++ b/tests/structure/test_dynotears.py @@ -249,6 +249,8 @@ def test_all_columns_in_structure(self, data_dynotears_p2): def test_isolated_nodes_exist(self, data_dynotears_p2): """Isolated nodes should still be in the learned structure""" + # X is the current time step of 5 features + # Y is the previous 2 time steps of 5 features sm = from_numpy_dynamic( data_dynotears_p2["X"], data_dynotears_p2["Y"], w_threshold=1 ) From 7d495fc8c796a17be6486184256dd2fe552e019f Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Fri, 14 Oct 2022 11:20:39 -0500 Subject: [PATCH 06/13] 100% test coverage --- causalnex/network/__init__.py | 1 - causalnex/network/network.py | 83 +--- causalnex/structure/dynotears.py | 16 +- causalnex/structure/structuremodel.py | 251 +++++----- tests/structure/test_dynamicstructuremodel.py | 445 ++++++++++++++---- tests/structure/test_dynotears.py | 63 +-- tests/test_dynamicbayesiannetwork.py | 38 -- 7 files changed, 519 insertions(+), 378 deletions(-) delete mode 100644 
tests/test_dynamicbayesiannetwork.py diff --git a/causalnex/network/__init__.py b/causalnex/network/__init__.py index 0d9de37..8b94a7a 100644 --- a/causalnex/network/__init__.py +++ b/causalnex/network/__init__.py @@ -33,4 +33,3 @@ __all__ = ["BayesianNetwork"] from .network import BayesianNetwork -from .network import DynamicBayesianNetwork diff --git a/causalnex/network/network.py b/causalnex/network/network.py index 3408059..11ba824 100644 --- a/causalnex/network/network.py +++ b/causalnex/network/network.py @@ -735,85 +735,4 @@ def _predict_probability_from_incomplete_data( probability = probability[cols] probability.columns = cols - return probability - - -class DynamicBayesianNetwork(BayesianNetwork): - """ - Base class for Dynamic Bayesian Network (DBN), a probabilistic weighted DAG where nodes represent variables, - edges represent the causal relationships between variables. - - ``DynamicBayesianNetwork`` stores nodes with their possible states, edges and - conditional probability distributions (CPDs) of each node. - - ``DynamicBayesianNetwork`` is built on top of the ``StructureModel``, which is an extension of ``networkx.DiGraph`` - (see :func:`causalnex.structure.structuremodel.StructureModel`). - - In order to define the ``DynamicBayesianNetwork``, users should provide a relevant ``StructureModel``. - Once ``DynamicBayesianNetwork`` is initialised, no changes to the ``StructureModel`` can be made - and CPDs can be learned from the data. - - The learned CPDs can be then used for likelihood estimation and predictions. - - Example: - :: - >>> # Create a Dynamic Bayesian Network with a manually defined DAG. 
- >>> from causalnex.structure import StructureModel - >>> from causalnex.network import DynamicBayesianNetwork - >>> - >>> sm = StructureModel() - >>> sm.add_edges_from([ - >>> ('rush_hour', 'traffic'), - >>> ('weather', 'traffic') - >>> ]) - >>> dbn = DynamicBayesianNetwork(sm) - >>> # A created ``DynamicBayesianNetwork`` stores nodes and edges defined by the ``StructureModel`` - >>> dbn.nodes - ['rush_hour', 'traffic', 'weather'] - >>> - >>> dbn.edges - [('rush_hour', 'traffic'), ('weather', 'traffic')] - >>> # A ``DynamicBayesianNetwork`` doesn't store any CPDs yet - >>> dbn.cpds - >>> {} - >>> - >>> # Learn the nodes' states from the data - >>> import pandas as pd - >>> data = pd.DataFrame({ - >>> 'rush_hour': [True, False, False, False, True, False, True], - >>> 'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'], - >>> 'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy'] - >>> }) - >>> dbn = dbn.fit_node_states(data) - >>> dbn.node_states - {'rush_hour': {False, True}, 'weather': {'Bad', 'Good', 'Terrible'}, 'traffic': {'heavy', 'light'}} - >>> # Learn the CPDs from the data - >>> dbn = dbn.fit_cpds(data) - >>> # Use the learned CPDs to make predictions on the unseen data - >>> test_data = pd.DataFrame({ - >>> 'rush_hour': [False, False, True, True], - >>> 'weather': ['Good', 'Bad', 'Good', 'Bad'] - >>> }) - >>> dbn.predict(test_data, "traffic").to_dict() - >>> {'traffic_prediction': {0: 'light', 1: 'heavy', 2: 'heavy', 3: 'heavy'}} - >>> dbn.predict_probability(test_data, "traffic").to_dict() - {'traffic_prediction': {0: 'light', 1: 'heavy', 2: 'heavy', 3: 'heavy'}} - {'traffic_light': {0: 0.75, 1: 0.25, 2: 0.3333333333333333, 3: 0.3333333333333333}, - 'traffic_heavy': {0: 0.25, 1: 0.75, 2: 0.6666666666666666, 3: 0.6666666666666666}} - """ - - def __init__(self, structure: DynamicStructureModel): - """ - Create a ``DynamicBayesianNetwork`` with a DAG defined by ``DynamicStructureModel``. 
- - Args: - structure: a graph representing a causal relationship between variables. - In the structure - - cycles are not allowed; - - multiple (parallel) edges are not allowed; - - isolated nodes and multiple components are not allowed. - - Raises: - ValueError: If the structure is not a connected DAG. - """ - super().__init__(structure) \ No newline at end of file + return probability \ No newline at end of file diff --git a/causalnex/structure/dynotears.py b/causalnex/structure/dynotears.py index 7e1c309..31f2ac6 100644 --- a/causalnex/structure/dynotears.py +++ b/causalnex/structure/dynotears.py @@ -130,7 +130,7 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments sm.add_weighted_edges_from( [ ( - DynamicStructureNode(idx_col[int(u[0])], u[-1]), # _format_name_from_pandas(idx_col, u), idx_col[int(u[0])] + DynamicStructureNode(idx_col[int(u[0])], u[-1]), DynamicStructureNode(idx_col[int(v[0])], v[-1]), w, ) @@ -142,20 +142,6 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments return sm -def _format_name_from_pandas(idx_col: Dict[int, str], from_numpy_node: str) -> str: - """ - Helper function for `from_pandas_dynamic`. converts a node from the `from_numpy_dynamic` format to the `from_pandas` - format - Args: - idx_col: map from variable to intdex - from_numpy_node: nodes in the structure model output by `from_numpy_dynamic`. - Returns: - nodes in from_pandas_dynamic format - """ - idx, lag_val = from_numpy_node.split("_lag") - return f"{idx_col[int(idx)]}_lag{lag_val}" - - def from_numpy_dynamic( # pylint: disable=too-many-arguments X: np.ndarray, Xlags: np.ndarray, diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index 6f76518..350b626 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -31,6 +31,7 @@ ``StructureModel`` is a class that describes relationships between variables as a graph. 
""" +from distutils.ccompiler import new_compiler from typing import Any, Hashable, List, Set, Tuple, Union, NamedTuple from collections import Iterable @@ -295,7 +296,7 @@ def threshold_till_dag(self): self.remove_edge(i, j) def get_markov_blanket( - self, nodes: Union[Any, List[Any], Set[Any]] + self, nodes: Union[Any, List[Any], Set[Any]], cls: nx.DiGraph = None ) -> "StructureModel": """ Get Markov blanket of specified target nodes @@ -325,7 +326,10 @@ def get_markov_blanket( blanket_nodes.add(child) blanket_nodes.update(self.predecessors(child)) - blanket = StructureModel() + if cls: + blanket = cls() + else: + blanket = StructureModel() blanket.add_nodes_from(blanket_nodes) blanket.add_weighted_edges_from( [ @@ -337,6 +341,9 @@ def get_markov_blanket( return blanket class DynamicStructureNode(NamedTuple): + """ + Used by DynamicStructureModel to store each node as a (node_name, lag) pair + """ node: Union[int, str] time_step: int @@ -344,60 +351,112 @@ def get_node_name(self): return f'{self.node}_lag{self.time_step}' -def checkargs(function): +def check_collection_type(c): + return isinstance(c, (list, set, types.GeneratorType)) + +def coerce_dsm_edges(arg): + """ + Used by DynamicStructureModel to convert edges as passed as primitive tuples to tuples of ``DynamicStructureNode``s. + An example input is [((0,0), (1,0), .5), ((1,0), (2,0), .5)]. 
This would be converted to + [(DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .5)] + """ + multi_edge = check_collection_type(arg) + if multi_edge: + if isinstance(arg, types.GeneratorType): + arg = list(arg) + if not all(isinstance(e, Tuple) for e in arg): + raise TypeError(f'Edges must be tuples containing 2 or 3 elements, received {arg}') + if all(isinstance(e[0], DynamicStructureNode) and isinstance(e[1], DynamicStructureNode) for e in arg): + return arg + else: + new_arg = [] + for e in arg: + if not isinstance(e[0], Tuple) or not isinstance(e[1], Tuple): + raise TypeError(f'Nodes in {e} must be tuples with node name and time step') + elif isinstance(e[0], DynamicStructureNode) and isinstance(e[1], DynamicStructureNode): + new_arg.append(e) + elif len(e) == 2: + if not isinstance(e[0], DynamicStructureNode) and not isinstance(e[1], DynamicStructureNode): + new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), DynamicStructureNode(e[1][0], e[1][1]))) + elif not isinstance(e[0], DynamicStructureNode): + new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), e[1])) + elif not isinstance(e[1], DynamicStructureNode): + new_arg.append((e[0], DynamicStructureNode(e[1][0], e[1][1]))) + elif len(e) == 3: + if not isinstance(e[0], DynamicStructureNode) and not isinstance(e[1], DynamicStructureNode): + new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), DynamicStructureNode(e[1][0], e[1][1]), e[2])) + elif not isinstance(e[0], DynamicStructureNode): + new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), e[1], e[2])) + elif not isinstance(e[1], DynamicStructureNode): + new_arg.append((e[0], DynamicStructureNode(e[1][0], e[1][1]), e[2])) + else: + raise TypeError(f'Argument {e} must be a tuple containing 2 or 3 elements') + return new_arg + else: + # if not isinstance(arg, Tuple): + # raise TypeError(f'Edges must be tuples containing 2 or 3 elements, received {arg}') + if not isinstance(arg[0], 
Tuple) or not isinstance(arg[1], Tuple): + raise TypeError(f'Nodes in {arg} must be tuples with node name and time step') + elif isinstance(arg[0], DynamicStructureNode) and isinstance(arg[1], DynamicStructureNode): + return arg + elif len(arg) == 2: + if not isinstance(arg[0], DynamicStructureNode) and not isinstance(arg[1], DynamicStructureNode): + return (DynamicStructureNode(arg[0][0], arg[0][1]), DynamicStructureNode(arg[1][0], arg[1][1])) + elif not isinstance(arg[0], DynamicStructureNode): + return (DynamicStructureNode(arg[0][0], arg[0][1]), arg[1]) + elif not isinstance(arg[1], DynamicStructureNode): + return (arg[0], DynamicStructureNode(arg[1][0], arg[1][1])) + elif len(arg) == 3: + if not isinstance(arg[0], DynamicStructureNode) and not isinstance(arg[1], DynamicStructureNode): + return (DynamicStructureNode(arg[0][0], arg[0][1]), DynamicStructureNode(arg[1][0], arg[1][1]), arg[2]) + elif not isinstance(arg[0], DynamicStructureNode): + return (DynamicStructureNode(arg[0][0], arg[0][1]), arg[1], arg[2]) + elif not isinstance(arg[1], DynamicStructureNode): + return (arg[0], DynamicStructureNode(arg[1][0], arg[1][1]), arg[2]) + else: + raise TypeError(f'Argument {arg} must be either a DynamicStructureNode or tuple containing 2 or 3 elements') + +def coerce_dsm_nodes(arg): """ - This function ensures the arguments passed to the methods in ``DynamicStructureModel`` are of the correct type. - Specifically that they are of type ``DynamicStructureNode``. 
+ Used by DynamicStructureModel to convert nodes passed as (node_name, lag) tuples into ``DynamicStructureNode``s """ - def _f(*arguments, **attr): - for index, argument in enumerate(inspect.getfullargspec(function)[0]): - if argument == 'self': - continue - try: - if isinstance(arguments[index], list): - for arg in arguments[index]: - if isinstance(arg, tuple) and not isinstance(arg, DynamicStructureNode): - if len(arg) == 3: - if not all(isinstance(n, DynamicStructureNode) for n in arg[:-1]): - raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) - else: - if not all(isinstance(n, DynamicStructureNode) for n in arg): - raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) - else: - if not isinstance(arg, function.__annotations__[argument].__args__[0]): - raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) - elif isinstance(arguments[index], types.GeneratorType): - # this comes from networkx, coerce into correct types - pass - elif hasattr(function.__annotations__[argument], '__args__'): - if not isinstance(arguments[index], function.__annotations__[argument].__args__[0]): - raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) - elif not isinstance(arguments[index], function.__annotations__[argument]): - raise TypeError("{} is not of type {}".format(arguments[index], function.__annotations__[argument])) - except IndexError as e: - # index error here means arg was passed implicitly - break - return function(*arguments, **attr) - _f.__doc__ = function.__doc__ - return _f - -def _validate_dsm_init_args(incoming_graph_data): - if isinstance(incoming_graph_data, list): - assert all(isinstance(n[0], DynamicStructureNode) and isinstance(n[1], DynamicStructureNode) for n in incoming_graph_data) + multi_node = check_collection_type(arg) + if multi_node: + if isinstance(arg, 
types.GeneratorType): + arg = list(arg) + if all(isinstance(n, DynamicStructureNode) for n in arg): + return arg + else: + new_arg = [] + for n in arg: + if isinstance(n, DynamicStructureNode): + new_arg.append(n) + elif isinstance(n, Tuple) and len(n) == 2: + new_arg.append(DynamicStructureNode(n[0], n[1])) + else: + raise TypeError(f'Argument {n} must be either a DynamicStructureNode or tuple containing 2 elements') + return new_arg + else: + if isinstance(arg, DynamicStructureNode): + return arg + elif isinstance(arg, Tuple) and len(arg) == 2: + return DynamicStructureNode(arg[0], arg[1]) + else: + raise TypeError(f'Argument {arg} must be either a DynamicStructureNode or tuple containing 2 elements') class DynamicStructureModel(StructureModel): """ - Base class for structure models, which are an extension of ``networkx.DiGraph``. + Base class for dynamic structure models, which are an extension of ``StructureModel``. - A ``StructureModel`` stores nodes and edges with optional data, or attributes. + A ``DynamicStructureModel`` stores ``DynamicStructureNode``s and edges with optional data, or attributes. Edges have one required attribute, "origin", which describes how the edge was created. Origin can be one of either unknown, learned, or expert. - StructureModel hold directed edges, describing a cause -> effect relationship. - Cycles are permitted within a ``StructureModel``. + DynamicStructureModel hold directed edges, describing a cause -> effect relationship. + Cycles are permitted within a ``DynamicStructureModel``. - Nodes can be arbitrary (hashable) Python objects with optional key/value attributes. + Nodes will be coerced into ``DynamicStructureNode``s with optional key/value attributes. By convention None is not used as a node. Edges are represented as links between nodes with optional key/value attributes. 
@@ -405,7 +464,7 @@ class DynamicStructureModel(StructureModel): def __init__(self, incoming_graph_data=None, origin="unknown", **attr): """ - Create a ``StructureModel`` with incoming_graph_data, which has come from some origin. + Create a ``DynamicStructureModel`` with incoming_graph_data, which has come from some origin. Args: incoming_graph_data (Optional): input graph (optional, default: None) @@ -421,19 +480,16 @@ def __init__(self, incoming_graph_data=None, origin="unknown", **attr): attr : Attributes to add to graph as key/value pairs (no attributes by default). """ - if incoming_graph_data is not None: - _validate_dsm_init_args(incoming_graph_data) super().__init__(incoming_graph_data, origin, **attr) - @checkargs def add_node(self, dnode: DynamicStructureNode): - super().add_nodes_from([dnode.get_node_name()]) + dnode = coerce_dsm_nodes(dnode) + super().add_node(dnode) - @checkargs def add_nodes(self, dnodes: List[DynamicStructureNode]): - node_names = [dnode.get_node_name() for dnode in dnodes] - super().add_nodes_from(node_names) + dnodes = coerce_dsm_nodes(dnodes) + super().add_nodes_from(dnodes) def to_directed_class(self): """ @@ -442,7 +498,6 @@ def to_directed_class(self): """ return DynamicStructureModel - @checkargs def get_target_subgraph(self, node: DynamicStructureNode) -> "DynamicStructureModel": """ Get the subgraph with the specified node. @@ -456,17 +511,9 @@ def get_target_subgraph(self, node: DynamicStructureNode) -> "DynamicStructureMo Raises: NodeNotFound: if the node is not found in the graph. 
""" - node_name = node.get_node_name() - if node_name in self.nodes: - for component in nx.weakly_connected_components(self): - subgraph = self.subgraph(component).copy() - - if node_name in set(subgraph.nodes): - return subgraph - - raise NodeNotFound(f"Node {node} not found in the graph") + node = coerce_dsm_nodes(node) + return super().get_target_subgraph(node) - @checkargs def get_markov_blanket( self, nodes: Union[DynamicStructureNode, List[DynamicStructureNode], Set[DynamicStructureNode]] ) -> "DynamicStructureModel": @@ -482,39 +529,17 @@ def get_markov_blanket( Raises: NodeNotFound: if one of the target nodes is not found in the graph. """ - if not isinstance(nodes, (list, set)): - nodes = [nodes] - - blanket_nodes = set() - - for node in set(nodes): # Ensure target nodes are unique - node_name = node.get_node_name() - if node_name not in set(self.nodes): - raise NodeNotFound(f"Node {node} not found in the graph") - - blanket_nodes.add(node_name) - blanket_nodes.update(self.predecessors(node_name)) + nodes = coerce_dsm_nodes(nodes) + return super().get_markov_blanket(nodes, DynamicStructureModel) - for child in self.successors(node_name): - blanket_nodes.add(child) - blanket_nodes.update(self.predecessors(child)) - - blanket = DynamicStructureModel() - blanket_dyn_nodes = [DynamicStructureNode(node_name[0], node_name[-1]) for node_name in blanket_nodes] - blanket.add_nodes(blanket_dyn_nodes) - - blanket_weighted_edges = [] - for u, v, w in self.edges(data="weight"): - if u in blanket_nodes and v in blanket_nodes: - blanket_weighted_edges.append((DynamicStructureNode(u[0], u[-1]), DynamicStructureNode(v[0], v[-1]), w)) - blanket.add_weighted_edges_from(blanket_weighted_edges) - return blanket + def add_edge(self, u: DynamicStructureNode, v: DynamicStructureNode, origin: str = "unknown", **attr): + edge = coerce_dsm_edges((u, v)) + super().add_edge(edge[0], edge[1], origin, **attr) # disabled: W0221: Parameters differ from overridden 'add_edge' method 
(arguments-differ) # this has been disabled because origin tracking is required for CausalGraphs # implementing it in this way allows all 3rd party libraries and applications to # integrate seamlessly, where edges will be given origin="unknown" where not provided - @checkargs def add_edges_from( self, ebunch_to_add: Union[Set[Tuple[DynamicStructureNode, DynamicStructureNode]], List[Tuple[DynamicStructureNode, DynamicStructureNode]]], @@ -524,7 +549,7 @@ def add_edges_from( """ Adds a bunch of causal relationships, u -> v. - If u or v do not currently exists in the ``StructureModel`` then they will be created. + If u or v do not currently exists in the ``DynamicStructureModel`` then they will be created. By default relationships will be given origin="unknown", but may also be given "learned" or "expert" origin. @@ -544,32 +569,14 @@ def add_edges_from( - expert: edges were created by a domain expert. **attr: Attributes to add to edge as key/value pairs (no attributes by default). """ - _validate_origin(origin) - - if isinstance(ebunch_to_add, Iterable) and not ebunch_to_add: - super().add_edges_from(ebunch_to_add, **attr) - return - - if isinstance(ebunch_to_add, types.GeneratorType): - dsn_ebunch = [] - for e in ebunch_to_add: - if len(e) == 3: - dsn_ebunch.append((DynamicStructureNode(e[0][0], e[0][-1]).get_node_name(), DynamicStructureNode(e[1][0], e[1][-1]).get_node_name(), e[2])) - else: - dsn_ebunch.append((DynamicStructureNode(e[0][0], e[0][-1]).get_node_name(), DynamicStructureNode(e[1][0], e[1][-1]).get_node_name())) - else: - if len(ebunch_to_add[0]) == 3: - dsn_ebunch = [(e[0].get_node_name(), e[1].get_node_name(), e[2]) for e in ebunch_to_add] - else: - dsn_ebunch = [(e[0].get_node_name(), e[1].get_node_name()) for e in ebunch_to_add] - attr.update({"origin": origin}) - super().add_edges_from(dsn_ebunch, **attr) + ebunch_to_add = coerce_dsm_edges(ebunch_to_add) + print(ebunch_to_add) + super().add_edges_from(ebunch_to_add, origin, **attr) # disabled: 
W0221: Parameters differ from overridden 'add_edge' method (arguments-differ) # this has been disabled because origin tracking is required for CausalGraphs # implementing it in this way allows all 3rd party libraries and applications to # integrate seamlessly, where edges will be given origin="unknown" where not provided - @checkargs def add_weighted_edges_from( self, ebunch_to_add: Union[Set[Tuple[DynamicStructureNode, DynamicStructureNode, float]], List[Tuple[DynamicStructureNode, DynamicStructureNode, float]]], @@ -580,7 +587,7 @@ def add_weighted_edges_from( """ Adds a bunch of weighted causal relationships, u -> v. - If u or v do not currently exists in the ``StructureModel`` then they will be created. + If u or v do not currently exists in the ``DynamicStructureModel`` then they will be created. By default relationships will be given origin="unknown", but may also be given "learned" or "expert" origin. @@ -602,15 +609,7 @@ def add_weighted_edges_from( - expert: edges were created by a domain expert. **attr: Attributes to add to edge as key/value pairs (no attributes by default). 
""" - _validate_origin(origin) - - if isinstance(ebunch_to_add, Iterable) and not ebunch_to_add: - super().add_weighted_edges_from(ebunch_to_add, weight=weight, **attr) - return - - if isinstance(ebunch_to_add, types.GeneratorType): - dsn_ebunch = [(DynamicStructureNode(e[0][0], e[0][-1]).get_node_name(), DynamicStructureNode(e[1][0], e[1][-1]).get_node_name(), e[2]) for e in ebunch_to_add] - else: - dsn_ebunch = [(e[0].get_node_name(), e[1].get_node_name(), e[2]) for e in ebunch_to_add] - attr.update({"origin": origin}) - super().add_weighted_edges_from(dsn_ebunch, weight=weight, **attr) \ No newline at end of file + ebunch_to_add = coerce_dsm_edges(ebunch_to_add) + if not isinstance(ebunch_to_add, list): + ebunch_to_add = [ebunch_to_add] + super().add_weighted_edges_from(ebunch_to_add, weight=weight, origin=origin, **attr) \ No newline at end of file diff --git a/tests/structure/test_dynamicstructuremodel.py b/tests/structure/test_dynamicstructuremodel.py index 1910617..8cf3f84 100644 --- a/tests/structure/test_dynamicstructuremodel.py +++ b/tests/structure/test_dynamicstructuremodel.py @@ -37,45 +37,45 @@ def test_init_has_origin(self): """Creating a DynamicStructureModel using constructor should give all edges unknown origin""" nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)] sm = DynamicStructureModel([(nodes[0], nodes[1])]) - assert (nodes[0].get_node_name(), nodes[1].get_node_name()) in sm.edges - assert (nodes[0].get_node_name(), nodes[1].get_node_name(), "unknown") in sm.edges.data("origin") + assert (nodes[0], nodes[1]) in sm.edges + assert (nodes[0], nodes[1], "unknown") in sm.edges.data("origin") def test_init_with_origin(self): """should be possible to specify origin during init""" nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)] sm = DynamicStructureModel([(nodes[0], nodes[1])], origin="learned") - assert (nodes[0].get_node_name(), nodes[1].get_node_name(), "learned") in sm.edges.data("origin") + assert (nodes[0], 
nodes[1], "learned") in sm.edges.data("origin") def test_edge_unknown_property(self): """should return only edges whose origin is unknown""" sm = DynamicStructureModel() - sm.add_edge(1, 2, origin="unknown") - sm.add_edge(1, 3, origin="learned") - sm.add_edge(1, 4, origin="expert") + sm.add_edge((1, 0), (2, 0), origin="unknown") + sm.add_edge((1, 0), (3, 0), origin="learned") + sm.add_edge((1, 0), (4, 0), origin="expert") - assert sm.edges_with_origin("unknown") == [(1, 2)] + assert sm.edges_with_origin("unknown") == [(DynamicStructureNode(1, 0), DynamicStructureNode(2, 0))] def test_edge_learned_property(self): """should return only edges whose origin is unknown""" sm = DynamicStructureModel() - sm.add_edge(1, 2, origin="unknown") - sm.add_edge(1, 3, origin="learned") - sm.add_edge(1, 4, origin="expert") + sm.add_edge((1, 0), (2, 0), origin="unknown") + sm.add_edge((1, 0), (3, 0), origin="learned") + sm.add_edge((1, 0), (4, 0), origin="expert") - assert sm.edges_with_origin("learned") == [(1, 3)] + assert sm.edges_with_origin("learned") == [(DynamicStructureNode(1, 0), DynamicStructureNode(3, 0))] def test_edge_expert_property(self): """should return only edges whose origin is unknown""" sm = DynamicStructureModel() - sm.add_edge(1, 2, origin="unknown") - sm.add_edge(1, 3, origin="learned") - sm.add_edge(1, 4, origin="expert") + sm.add_edge((1, 0), (2, 0), origin="unknown") + sm.add_edge((1, 0), (3, 0), origin="learned") + sm.add_edge((1, 0), (4, 0), origin="expert") - assert sm.edges_with_origin("expert") == [(1, 4)] + assert sm.edges_with_origin("expert") == [(DynamicStructureNode(1, 0), DynamicStructureNode(4, 0))] def test_to_directed(self): """should create a structure model""" @@ -85,10 +85,11 @@ def test_to_directed(self): edges = [(nodes[0], nodes[1]), (nodes[1], nodes[0]), (nodes[1], nodes[2]), (nodes[2], nodes[3])] sm.add_edges_from(edges) - + dag = sm.to_directed() + assert isinstance(dag, DynamicStructureModel) - assert all((edge[0].get_node_name(), 
edge[1].get_node_name()) in dag.edges for edge in edges) + assert all((edge[0], edge[1]) in dag.edges for edge in edges) def test_to_undirected(self): """should create an undirected Graph""" @@ -101,8 +102,8 @@ def test_to_undirected(self): udg = sm.to_undirected() - assert all((edge[0].get_node_name(), edge[1].get_node_name()) in udg.edges for edge in [(nodes[1], nodes[2]), (nodes[2], nodes[3])]) - assert (nodes[0].get_node_name(), nodes[1].get_node_name()) in udg.edges or (nodes[1].get_node_name(), nodes[0].get_node_name()) in udg.edges + assert all((edge[0], edge[1]) in udg.edges for edge in [(nodes[1], nodes[2]), (nodes[2], nodes[3])]) + assert (nodes[0], nodes[1]) in udg.edges or (nodes[1], nodes[0]) in udg.edges assert len(udg.edges) == 3 @@ -111,37 +112,37 @@ def test_add_edge_default(self): """edges added with default origin should be identified as unknown origin""" sm = DynamicStructureModel() - sm.add_edge(1, 2) + sm.add_edge((1, 0), (2, 0)) - assert (1, 2) in sm.edges - assert (1, 2, "unknown") in sm.edges.data("origin") + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "unknown") in sm.edges.data("origin") def test_add_edge_unknown(self): """edges added with unknown origin should be labelled as unknown origin""" sm = DynamicStructureModel() - sm.add_edge(1, 2, "unknown") + sm.add_edge((1, 0), (2, 0), "unknown") - assert (1, 2) in sm.edges - assert (1, 2, "unknown") in sm.edges.data("origin") + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "unknown") in sm.edges.data("origin") def test_add_edge_learned(self): """edges added with learned origin should be labelled as learned origin""" sm = DynamicStructureModel() - sm.add_edge(1, 2, "learned") + sm.add_edge((1, 0), (2, 0), "learned") - assert (1, 2) in sm.edges - assert (1, 2, "learned") in sm.edges.data("origin") + 
assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "learned") in sm.edges.data("origin") def test_add_edge_expert(self): """edges added with expert origin should be labelled as expert origin""" sm = DynamicStructureModel() - sm.add_edge(1, 2, "expert") + sm.add_edge((1, 0), (2, 0), "expert") - assert (1, 2) in sm.edges - assert (1, 2, "expert") in sm.edges.data("origin") + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "expert") in sm.edges.data("origin") def test_add_edge_other(self): """edges added with other origin should throw an error""" @@ -149,37 +150,37 @@ def test_add_edge_other(self): sm = DynamicStructureModel() with pytest.raises(ValueError, match="^Unknown origin: must be one of.*$"): - sm.add_edge(1, 2, "other") + sm.add_edge((1, 0), (2, 0), "other") def test_add_edge_custom_attr(self): """it should be possible to add an edge with custom attributes""" sm = DynamicStructureModel() - sm.add_edge(1, 2, x="Y") + sm.add_edge((1, 0), (2, 0), x="Y") - assert (1, 2) in sm.edges - assert (1, 2, "Y") in sm.edges.data("x") + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "Y") in sm.edges.data("x") def test_add_edge_multiple_times(self): """adding an edge again should update the edges origin attr""" sm = DynamicStructureModel() - sm.add_edge(1, 2, origin="unknown") - assert (1, 2, "unknown") in sm.edges.data("origin") - sm.add_edge(1, 2, origin="learned") - assert (1, 2, "learned") in sm.edges.data("origin") + sm.add_edge((1, 0), (2, 0), origin="unknown") + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "unknown") in sm.edges.data("origin") + sm.add_edge((1, 0), (2, 0), origin="learned") + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "learned") in 
sm.edges.data("origin") def test_add_multiple_edges(self): """it should be possible to add multiple edges with different origins""" sm = DynamicStructureModel() - sm.add_edge(1, 2, origin="unknown") - sm.add_edge(1, 3, origin="learned") - sm.add_edge(1, 4, origin="expert") + sm.add_edge((1, 0), (2, 0), origin="unknown") + sm.add_edge((1, 0), (3, 0), origin="learned") + sm.add_edge((1, 0), (4, 0), origin="expert") - assert (1, 2, "unknown") in sm.edges.data("origin") - assert (1, 3, "learned") in sm.edges.data("origin") - assert (1, 4, "expert") in sm.edges.data("origin") + assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "unknown") in sm.edges.data("origin") + assert (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0), "learned") in sm.edges.data("origin") + assert (DynamicStructureNode(1, 0), DynamicStructureNode(4, 0), "expert") in sm.edges.data("origin") class TestDynamicStructureModelAddEdgesFrom: @@ -191,8 +192,8 @@ def test_add_edges_from_default(self): edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges) - assert all((edge[0].get_node_name(), edge[1].get_node_name()) in sm.edges for edge in edges) - assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v in edges) + assert all((edge[0], edge[1]) in sm.edges for edge in edges) + assert all((u, v, "unknown") in sm.edges.data("origin") for u, v in edges) def test_add_edges_from_unknown(self): """edges added with unknown origin should be labelled as unknown origin""" @@ -203,8 +204,8 @@ def test_add_edges_from_unknown(self): edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, "unknown") - assert all((u.get_node_name(), v.get_node_name()) in sm.edges for u, v in edges) - assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v in edges) + assert all((u, v) in sm.edges for u, v in edges) + assert all((u, v, "unknown") in sm.edges.data("origin") for u, v in edges) 
def test_add_edges_from_learned(self): """edges added with learned origin should be labelled as learned origin""" @@ -215,8 +216,8 @@ def test_add_edges_from_learned(self): edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, "learned") - assert all((u.get_node_name(), v.get_node_name()) in sm.edges for u, v in edges) - assert all((u.get_node_name(), v.get_node_name(), "learned") in sm.edges.data("origin") for u, v in edges) + assert all((u, v) in sm.edges for u, v in edges) + assert all((u, v, "learned") in sm.edges.data("origin") for u, v in edges) def test_add_edges_from_expert(self): """edges added with expert origin should be labelled as expert origin""" @@ -227,8 +228,8 @@ def test_add_edges_from_expert(self): edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, "expert") - assert all((u.get_node_name(), v.get_node_name()) in sm.edges for u, v in edges) - assert all((u.get_node_name(), v.get_node_name(), "expert") in sm.edges.data("origin") for u, v in edges) + assert all((u, v) in sm.edges for u, v in edges) + assert all((u, v, "expert") in sm.edges.data("origin") for u, v in edges) def test_add_edges_from_other(self): """edges added with other origin should throw an error""" @@ -248,8 +249,8 @@ def test_add_edges_from_custom_attr(self): edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, x="Y") - assert all((u.get_node_name(), v.get_node_name()) in sm.edges for u, v in edges) - assert all((u.get_node_name(), v.get_node_name(), "Y") in sm.edges.data("x") for u, v in edges) + assert all((u, v) in sm.edges for u, v in edges) + assert all((u, v, "Y") in sm.edges.data("x") for u, v in edges) def test_add_edges_from_multiple_times(self): """adding edges again should update the edges origin attr""" @@ -259,9 +260,9 @@ def test_add_edges_from_multiple_times(self): edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, "unknown") - assert all((u.get_node_name(), 
v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v in edges) + assert all((u, v, "unknown") in sm.edges.data("origin") for u, v in edges) sm.add_edges_from(edges, "learned") - assert all((u.get_node_name(), v.get_node_name(), "learned") in sm.edges.data("origin") for u, v in edges) + assert all((u, v, "learned") in sm.edges.data("origin") for u, v in edges) def test_add_multiple_edges(self): """it should be possible to add multiple edges with different origins""" @@ -272,9 +273,9 @@ def test_add_multiple_edges(self): sm.add_edges_from([(nodes[0], nodes[2])], origin="learned") sm.add_edges_from([(nodes[0], nodes[3])], origin="expert") - assert (nodes[0].get_node_name(), nodes[1].get_node_name(), "unknown") in sm.edges.data("origin") - assert (nodes[0].get_node_name(), nodes[2].get_node_name(), "learned") in sm.edges.data("origin") - assert (nodes[0].get_node_name(), nodes[3].get_node_name(), "expert") in sm.edges.data("origin") + assert (nodes[0], nodes[1], "unknown") in sm.edges.data("origin") + assert (nodes[0], nodes[2], "learned") in sm.edges.data("origin") + assert (nodes[0], nodes[3], "expert") in sm.edges.data("origin") class TestDynamicStructureModelAddWeightedEdgesFrom: @@ -286,8 +287,8 @@ def test_add_weighted_edges_from_default(self): edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] sm.add_weighted_edges_from(edges) - assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) - assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v, w in edges) + assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u, v, "unknown") in sm.edges.data("origin") for u, v, w in edges) def test_add_weighted_edges_from_unknown(self): """edges added with unknown origin should be labelled as unknown origin""" @@ -298,8 +299,8 @@ def test_add_weighted_edges_from_unknown(self): edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] 
sm.add_weighted_edges_from(edges, origin="unknown") - assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) - assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v, w in edges) + assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u, v, "unknown") in sm.edges.data("origin") for u, v, w in edges) def test_add_weighted_edges_from_learned(self): """edges added with learned origin should be labelled as learned origin""" @@ -310,8 +311,8 @@ def test_add_weighted_edges_from_learned(self): edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] sm.add_weighted_edges_from(edges, origin="learned") - assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) - assert all((u.get_node_name(), v.get_node_name(), "learned") in sm.edges.data("origin") for u, v, w in edges) + assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u, v, "learned") in sm.edges.data("origin") for u, v, w in edges) def test_add_weighted_edges_from_expert(self): """edges added with expert origin should be labelled as expert origin""" @@ -323,8 +324,8 @@ def test_add_weighted_edges_from_expert(self): edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] sm.add_weighted_edges_from(edges, origin="expert") - assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) - assert all((u.get_node_name(), v.get_node_name(), "expert") in sm.edges.data("origin") for u, v, w in edges) + assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u, v, "expert") in sm.edges.data("origin") for u, v, w in edges) def test_add_weighted_edges_from_other(self): """edges added with other origin should throw an error""" @@ -345,8 +346,8 @@ def test_add_weighted_edges_from_custom_attr(self): edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] 
sm.add_weighted_edges_from(edges, x="Y") - assert all((u.get_node_name(), v.get_node_name(), w) in sm.edges.data("weight") for u, v, w in edges) - assert all((u.get_node_name(), v.get_node_name(), "Y") in sm.edges.data("x") for u, v, _ in edges) + assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) + assert all((u, v, "Y") in sm.edges.data("x") for u, v, _ in edges) def test_add_weighted_edges_from_multiple_times(self): """adding edges again should update the edges origin attr""" @@ -358,10 +359,10 @@ def test_add_weighted_edges_from_multiple_times(self): edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] sm.add_weighted_edges_from(edges, origin="unknown") - assert all((u.get_node_name(), v.get_node_name(), "unknown") in sm.edges.data("origin") for u, v, _ in edges) + assert all((u, v, "unknown") in sm.edges.data("origin") for u, v, _ in edges) sm.add_weighted_edges_from(edges, origin="learned") - assert all((u.get_node_name(), v.get_node_name(), "learned") in sm.edges.data("origin") for u, v, _ in edges) + assert all((u, v, "learned") in sm.edges.data("origin") for u, v, _ in edges) def test_add_multiple_weighted_edges(self): """it should be possible to add multiple edges with different origins""" @@ -372,9 +373,9 @@ def test_add_multiple_weighted_edges(self): sm.add_weighted_edges_from([(nodes[0], nodes[2], 0.5)], origin="learned") sm.add_weighted_edges_from([(nodes[0], nodes[3], 0.5)], origin="expert") - assert ('1_lag0', '2_lag0', "unknown") in sm.edges.data("origin") - assert ('1_lag0', '3_lag0', "learned") in sm.edges.data("origin") - assert ('1_lag0', '4_lag0', "expert") in sm.edges.data("origin") + assert (nodes[0], nodes[1], "unknown") in sm.edges.data("origin") + assert (nodes[0], nodes[2], "learned") in sm.edges.data("origin") + assert (nodes[0], nodes[3], "expert") in sm.edges.data("origin") class TestDynamicStructureModelRemoveEdgesBelowThreshold: @@ -390,7 +391,7 @@ def test_remove_edges_below_threshold(self): 
sm.add_weighted_edges_from(weak_edges) sm.remove_edges_below_threshold(0.7) - assert set(sm.edges(data="weight")) == set((u.get_node_name(), v.get_node_name(), w) for u, v, w in strong_edges) + assert set(sm.edges(data="weight")) == set((u, v, w) for u, v, w in strong_edges) def test_negative_weights(self): """Negative edges whose absolute value is greater than the defined threshold should not be removed""" @@ -406,7 +407,7 @@ def test_negative_weights(self): sm.remove_edges_below_threshold(0.7) - assert set(sm.edges(data="weight")) == set((u.get_node_name(), v.get_node_name(), w) for u, v, w in strong_edges) + assert set(sm.edges(data="weight")) == set((u, v, w) for u, v, w in strong_edges) def test_equal_weights(self): """Edges whose absolute value is equal to the defined threshold should not be removed""" @@ -424,8 +425,8 @@ def test_equal_weights(self): sm.remove_edges_below_threshold(0.6) assert set(sm.edges(data="weight")) == set.union( - set((u.get_node_name(), v.get_node_name(), w) for u, v, w in strong_edges), - set((u.get_node_name(), v.get_node_name(), w) for u, v, w in equal_edges) + set((u, v, w) for u, v, w in strong_edges), + set((u, v, w) for u, v, w in equal_edges) ) def test_graph_with_no_edges(self): @@ -437,7 +438,7 @@ def test_graph_with_no_edges(self): sm.add_nodes(nodes) sm.remove_edges_below_threshold(0.6) - assert set(sm.nodes) == set([node.get_node_name() for node in nodes]) + assert set(sm.nodes) == set([node for node in nodes]) assert set(sm.edges) == set() @@ -525,10 +526,10 @@ def test_different_origins_and_weights(self): largest_subgraph = sm.get_largest_subgraph() assert set(largest_subgraph.edges.data("origin")) == { - (nodes[0].get_node_name(), nodes[1].get_node_name(), "unknown"), - (nodes[0].get_node_name(), nodes[2].get_node_name(), "learned"), + (nodes[0], nodes[1], "unknown"), + (nodes[0], nodes[2], "learned"), } - assert set(largest_subgraph.edges.data("weight")) == {(nodes[0].get_node_name(), nodes[1].get_node_name(), 2.0), 
(nodes[0].get_node_name(), nodes[2].get_node_name(), 1.0)} + assert set(largest_subgraph.edges.data("weight")) == {(nodes[0], nodes[1], 2.0), (nodes[0], nodes[2], 1.0)} class TestDynamicStructureModelGetTargetSubgraph: @@ -614,7 +615,8 @@ def test_isolates(self): subgraph = sm.get_target_subgraph(DynamicStructureNode(1, 0)) expected_graph = DynamicStructureModel() expected_graph.add_node(DynamicStructureNode(1, 0)) - + print(f'subgraph nodes {subgraph.nodes}\n') + print(f'expected nodes {expected_graph.nodes}') assert set(subgraph.nodes) == set(expected_graph.nodes) assert set(subgraph.edges) == set(expected_graph.edges) @@ -648,10 +650,10 @@ def test_different_origins_and_weights(self): subgraph = sm.get_target_subgraph(nodes[1]) assert set(subgraph.edges.data("origin")) == { - (nodes[0].get_node_name(), nodes[1].get_node_name(), "unknown"), - (nodes[0].get_node_name(), nodes[2].get_node_name(), "learned"), + (nodes[0], nodes[1], "unknown"), + (nodes[0], nodes[2], "learned"), } - assert set(subgraph.edges.data("weight")) == {(nodes[0].get_node_name(), nodes[1].get_node_name(), 2.0), (nodes[0].get_node_name(), nodes[2].get_node_name(), 1.0)} + assert set(subgraph.edges.data("weight")) == {(nodes[0], nodes[1], 2.0), (nodes[0], nodes[2], 1.0)} def test_instance_type(self): """The subgraph returned should still be a DynamicStructureModel instance""" @@ -671,7 +673,7 @@ def test_get_target_subgraph_twice(self): sm.add_edges_from([(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[4], nodes[5])]) subgraph = sm.get_target_subgraph(nodes[0]) - subgraph.remove_edge(nodes[0].get_node_name(), nodes[1].get_node_name()) + subgraph.remove_edge(nodes[0], nodes[1]) subgraph = subgraph.get_target_subgraph(nodes[1]) expected_graph = DynamicStructureModel() @@ -824,3 +826,270 @@ def test_instance_type(self): subgraph = sm.get_markov_blanket(nodes[2]) assert isinstance(subgraph, DynamicStructureModel) + +class TestDynamicStructureModelEdgeCoercion: + + def 
test_edge_not_tuple(self): + edges = [((1, 0), (3, 0), .5), 6] + sm = DynamicStructureModel() + + with pytest.raises( + TypeError, + match=re.escape(f"Edges must be tuples containing 2 or 3 elements, received {edges}"), + ): + sm.add_edges_from(edges) + + def test_multi_edge_not_dsn(self): + edges = [((0,0), (1,0)), ((1,0), (2,0))] + sm = DynamicStructureModel() + sm.add_edges_from(edges) + + expected_edges = [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), (DynamicStructureNode(1,0), DynamicStructureNode(2,0))] + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected_edges) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == set(expected_graph.edges) + + def test_weighted_multi_edge_not_dsn(self): + edges = [((0,0), (1,0), .5), ((1,0), (2,0), .7)] + sm = DynamicStructureModel() + sm.add_weighted_edges_from(edges) + + expected_edges = [(DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .7)] + expected_graph = DynamicStructureModel() + expected_graph.add_weighted_edges_from(expected_edges) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == set(expected_graph.edges) + + @pytest.mark.parametrize( + "input_edges, expected_edges", + [ + ( + [((0,0), (1,0), .5), ((1,0), (2,0), .7)], + [(DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .7)] + ), + ( + [((0,0), (1,0)), ((1,0), (2,0))], + [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), (DynamicStructureNode(1,0), DynamicStructureNode(2,0))] + ) + ], + ) + def test_multi_edge_dsn(self, input_edges, expected_edges): + sm = DynamicStructureModel() + weighted = len(input_edges[0]) == 3 + if not weighted: + sm.add_edges_from(input_edges) + else: + sm.add_weighted_edges_from(input_edges) + + expected_graph = DynamicStructureModel() + if not weighted: + expected_graph.add_edges_from(expected_edges) + 
else: + expected_graph.add_weighted_edges_from(expected_edges) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == set(expected_graph.edges) + + def test_node_not_tuple(self): + edges = [((1, 0), (3, 0), .5), ((1, 0), 3, .7)] + sm = DynamicStructureModel() + + with pytest.raises( + TypeError, + match=re.escape(f"Nodes in {edges[1]} must be tuples with node name and time step"), + ): + sm.add_edges_from(edges) + + @pytest.mark.parametrize( + "input_edges, expected_edges", + [ + ( + [(DynamicStructureNode(0,0), (1,0), .5), ((1,0), DynamicStructureNode(2,0), .7)], + [(DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .7)] + ), + ( + [(DynamicStructureNode(0,0), (1,0)), ((1,0), DynamicStructureNode(2,0))], + [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), (DynamicStructureNode(1,0), DynamicStructureNode(2,0))] + ), + ( + [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), ((1,0), DynamicStructureNode(2,0))], + [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), (DynamicStructureNode(1,0), DynamicStructureNode(2,0))] + ) + ], + ) + def test_multi_edge_one_dsn(self, input_edges, expected_edges): + sm = DynamicStructureModel() + weighted = len(input_edges[0]) == 3 + if not weighted: + sm.add_edges_from(input_edges) + else: + sm.add_weighted_edges_from(input_edges) + + expected_graph = DynamicStructureModel() + if not weighted: + expected_graph.add_edges_from(expected_edges) + else: + expected_graph.add_weighted_edges_from(expected_edges) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == set(expected_graph.edges) + + def test_multi_edge_bad_tuple(self): + edges = [((0,0), (1,0), .5), ((1,0), (2,0), .7, .8)] + sm = DynamicStructureModel() + with pytest.raises( + TypeError, + match=re.escape(f"Argument {edges[1]} must be a tuple containing 2 or 3 elements"), + ): + sm.add_weighted_edges_from(edges) + + # def 
test_single_edge_not_tuple(self): + # edge = 6 + # sm = DynamicStructureModel() + + # with pytest.raises( + # TypeError, + # match=re.escape(f"Edges must be tuples containing 2 or 3 elements, received {edge}"), + # ): + # sm.add_edge(edge) + + def test_single_edge_node_not_tuple(self): + u = (1, 0) + v = 3 + sm = DynamicStructureModel() + + with pytest.raises( + TypeError, + match=re.escape(f"Nodes in {(u, v)} must be tuples with node name and time step"), + ): + sm.add_edge(u, v) + + @pytest.mark.parametrize( + "input_edge, expected_edge", + [ + ( + ((0,0), (1,0), .5), + (DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5) + ), + ( + ((0,0), (1,0)), + (DynamicStructureNode(0,0), DynamicStructureNode(1,0)) + ), + ( + (DynamicStructureNode(0,0), DynamicStructureNode(1,0)), + (DynamicStructureNode(0,0), DynamicStructureNode(1,0)) + ), + ( + (DynamicStructureNode(0,0), (1,0)), + (DynamicStructureNode(0,0), DynamicStructureNode(1,0)) + ), + ( + ((0,0), DynamicStructureNode(1,0)), + (DynamicStructureNode(0,0), DynamicStructureNode(1,0)) + ), + ( + (DynamicStructureNode(0,0), (1,0), .5), + (DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5) + ), + ( + ((0,0), DynamicStructureNode(1,0), .5), + (DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5) + ) + ], + ) + def test_single_edge_dsn(self, input_edge, expected_edge): + sm = DynamicStructureModel() + weighted = len(input_edge) == 3 + if not weighted: + sm.add_edge(input_edge[0], input_edge[1]) + else: + sm.add_weighted_edges_from(input_edge) + + expected_graph = DynamicStructureModel() + if not weighted: + expected_graph.add_edge(expected_edge[0], expected_edge[1]) + else: + expected_graph.add_weighted_edges_from(expected_edge) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == set(expected_graph.edges) + + def test_single_edge_bad_tuple(self): + edge = ((1,0), (2,0), .7, .8) + sm = DynamicStructureModel() + with pytest.raises( + TypeError, + match=re.escape(f"Argument 
{edge} must be either a DynamicStructureNode or tuple containing 2 or 3 elements"), + ): + sm.add_weighted_edges_from(edge) + +class TestDynamicStructureModelNodeCoercion: + + @pytest.mark.parametrize( + "input_nodes, expected_nodes", + [ + ( + [(0,0), (1,0)], + [DynamicStructureNode(0,0), DynamicStructureNode(1,0)] + ), + ( + [DynamicStructureNode(0,0), (1,0)], + [DynamicStructureNode(0,0), DynamicStructureNode(1,0)] + ), + ( + (DynamicStructureNode(n, 0) for n in range(2)), + [DynamicStructureNode(0,0), DynamicStructureNode(1,0)] + ) + ], + ) + def test_multi_node(self, input_nodes, expected_nodes): + sm = DynamicStructureModel() + sm.add_nodes(input_nodes) + + expected_graph = DynamicStructureModel() + expected_graph.add_nodes(expected_nodes) + + assert set(sm.nodes) == set(expected_graph.nodes) + + @pytest.mark.parametrize( + "input_node, expected_node", + [ + ( + (0,0), + DynamicStructureNode(0,0) + ), + ( + DynamicStructureNode(0,0), + DynamicStructureNode(0,0) + ) + ], + ) + def test_single_node(self, input_node, expected_node): + sm = DynamicStructureModel() + sm.add_nodes(input_node) + + expected_graph = DynamicStructureModel() + expected_graph.add_nodes(expected_node) + + assert set(sm.nodes) == set(expected_graph.nodes) + + def test_multi_node_bad_tuple(self): + nodes = [(0,0), (1,0,1)] + sm = DynamicStructureModel() + with pytest.raises( + TypeError, + match=re.escape(f"Argument {nodes[1]} must be either a DynamicStructureNode or tuple containing 2 elements"), + ): + sm.add_nodes(nodes) + + def test_single_node_bad_tuple(self): + node = (1,0,1) + sm = DynamicStructureModel() + with pytest.raises( + TypeError, + match=re.escape(f"Argument {node} must be either a DynamicStructureNode or tuple containing 2 elements"), + ): + sm.add_node(node) \ No newline at end of file diff --git a/tests/structure/test_dynotears.py b/tests/structure/test_dynotears.py index 9ce5771..3aa3c58 100644 --- a/tests/structure/test_dynotears.py +++ 
b/tests/structure/test_dynotears.py @@ -26,6 +26,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re +from causalnex.structure.structuremodel import DynamicStructureNode import networkx as nx import numpy as np @@ -95,9 +96,10 @@ def test_naming_nodes(self, data_dynotears_p3): pattern = re.compile(r"[0-5]_lag[0-3]") for node in sm.nodes: - match = pattern.match(node) + node_name = node.get_node_name() + match = pattern.match(node_name) assert match - assert match.group() == node + assert match.group() == node_name def test_inter_edges(self, data_dynotears_p3): """ @@ -119,22 +121,22 @@ def test_expected_structure_learned_p1(self, data_dynotears_p1): data_dynotears_p1["X"], data_dynotears_p1["Y"], w_threshold=0.2 ) w_edges = [ - (f"{i}_lag0", f"{j}_lag0") + (DynamicStructureNode(i, 0), DynamicStructureNode(j, 0)) for i in range(5) for j in range(5) if data_dynotears_p1["W"][i, j] != 0 ] a_edges = [ - (f"{i % 5}_lag{1 + i // 5}", f"{j}_lag0") + (DynamicStructureNode(i % 5, 1 + i // 5), DynamicStructureNode(j, 0)) for i in range(5) for j in range(5) if data_dynotears_p1["A"][i, j] != 0 ] edges_in_sm_and_a = [el for el in sm.edges if el in a_edges] - sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0]] + sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0].get_node_name()] - assert sorted([el for el in sm.edges if "lag0" in el[0]]) == sorted(w_edges) + assert sorted([el for el in sm.edges if "lag0" in el[0].get_node_name()]) == sorted(w_edges) assert len(edges_in_sm_and_a) / len(a_edges) > 0.6 assert len(edges_in_sm_and_a) / len(sm_inter_edges) > 0.9 @@ -148,21 +150,21 @@ def test_expected_structure_learned_p2(self, data_dynotears_p2): data_dynotears_p2["X"], data_dynotears_p2["Y"], w_threshold=0.25 ) w_edges = [ - (f"{i}_lag0", f"{j}_lag0") + (DynamicStructureNode(i, 0), DynamicStructureNode(j, 0)) for i in range(5) for j in range(5) if data_dynotears_p2["W"][i, j] != 0 ] a_edges = 
[ - (f"{i % 5}_lag{1 + i // 5}", f"{j}_lag0") + (DynamicStructureNode(i % 5, 1 + i // 5), DynamicStructureNode(j, 0)) for i in range(5) for j in range(5) if data_dynotears_p2["A"][i, j] != 0 ] edges_in_sm_and_a = [el for el in sm.edges if el in a_edges] - sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0]] - sm_intra_edges = [el for el in sm.edges if "lag0" in el[0]] + sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0].get_node_name()] + sm_intra_edges = [el for el in sm.edges if "lag0" in el[0].get_node_name()] assert len([el for el in sm_intra_edges if el not in w_edges]) == 0 assert ( @@ -244,7 +246,7 @@ def test_all_columns_in_structure(self, data_dynotears_p2): data_dynotears_p2["Y"], ) assert sorted(sm.nodes) == [ - f"{var}_lag{l_val}" for var in range(5) for l_val in range(3) + DynamicStructureNode(var, l_val) for var in range(5) for l_val in range(3) ] def test_isolated_nodes_exist(self, data_dynotears_p2): @@ -277,7 +279,8 @@ def test_certain_relationships_get_near_certain_weight(self): ) sm = from_numpy_dynamic(data.values[1:], data.values[:-1], w_threshold=0.1) edge = ( - sm.get_edge_data("1_lag0", "0_lag0") or sm.get_edge_data("0_lag0", "1_lag0") + sm.get_edge_data(DynamicStructureNode(1, 0), DynamicStructureNode(0, 0)) or + sm.get_edge_data(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)) )["weight"] assert 0.99 < edge <= 1.01 @@ -291,7 +294,8 @@ def test_inverse_relationships_get_negative_weight(self): ) sm = from_numpy_dynamic(data.values[1:], data.values[:-1], w_threshold=0.1) edge = ( - sm.get_edge_data("1_lag0", "0_lag0") or sm.get_edge_data("0_lag0", "1_lag0") + sm.get_edge_data(DynamicStructureNode(1, 0), DynamicStructureNode(0, 0)) or + sm.get_edge_data(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)) )["weight"] assert -1.01 < edge <= -0.99 @@ -365,9 +369,10 @@ def test_naming_nodes(self, data_dynotears_p3): pattern = re.compile(r"[abcde]_lag[0-3]") for node in sm.nodes: - match = pattern.match(node) + 
node_name = node.get_node_name() + match = pattern.match(node_name) assert match - assert match.group() == node + assert match.group() == node_name def test_inter_edges(self, data_dynotears_p3): """ @@ -397,15 +402,15 @@ def test_expected_structure_learned_p1(self, data_dynotears_p1): ) map_ = dict(zip(range(5), ["a", "b", "c", "d", "e"])) w_edges = [ - (f"{map_[i]}_lag0", f"{map_[j]}_lag0") + (DynamicStructureNode(map_[i], 0), DynamicStructureNode(map_[j], 0)) for i in range(5) for j in range(5) if data_dynotears_p1["W"][i, j] != 0 ] a_edges = [ ( - f"{map_[i % 5]}_lag{1 + i // 5}", - f"{map_[j]}_lag0", + DynamicStructureNode(map_[i % 5], 1 + i // 5), + DynamicStructureNode(map_[j], 0), ) for i in range(5) for j in range(5) @@ -413,8 +418,8 @@ def test_expected_structure_learned_p1(self, data_dynotears_p1): ] edges_in_sm_and_a = [el for el in sm.edges if el in a_edges] - sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0]] - assert sorted(el for el in sm.edges if "lag0" in el[0]) == sorted(w_edges) + sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0].get_node_name()] + assert sorted(el for el in sm.edges if "lag0" in el[0].get_node_name()) == sorted(w_edges) assert len(edges_in_sm_and_a) / len(a_edges) > 0.6 assert len(edges_in_sm_and_a) / len(sm_inter_edges) > 0.9 @@ -436,15 +441,15 @@ def test_expected_structure_learned_p2(self, data_dynotears_p2): ) map_ = dict(zip(range(5), ["a", "b", "c", "d", "e"])) w_edges = [ - (f"{map_[i]}_lag0", f"{map_[j]}_lag0") + (DynamicStructureNode(map_[i], 0), DynamicStructureNode(map_[j], 0)) for i in range(5) for j in range(5) if data_dynotears_p2["W"][i, j] != 0 ] a_edges = [ ( - f"{map_[i % 5]}_lag{1 + i // 5}", - f"{map_[j]}_lag0", + DynamicStructureNode(map_[i % 5], 1 + i // 5), + DynamicStructureNode(map_[j], 0), ) for i in range(5) for j in range(5) @@ -452,8 +457,8 @@ def test_expected_structure_learned_p2(self, data_dynotears_p2): ] edges_in_sm_and_a = [el for el in sm.edges if el in a_edges] - 
sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0]] - sm_intra_edges = [el for el in sm.edges if "lag0" in el[0]] + sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0].get_node_name()] + sm_intra_edges = [el for el in sm.edges if "lag0" in el[0].get_node_name()] assert len([el for el in sm_intra_edges if el not in w_edges]) == 0 assert ( @@ -537,7 +542,7 @@ def test_all_columns_in_structure(self, data_dynotears_p2): w_threshold=0.4, ) assert sorted(sm.nodes) == [ - f"{var}_lag{l_val}" + DynamicStructureNode(var, l_val) for var in ["a", "b", "c", "d", "e"] for l_val in range(3) ] @@ -576,7 +581,8 @@ def test_certain_relationships_get_near_certain_weight(self): ) sm = from_pandas_dynamic(data, p=1, w_threshold=0.1) edge = ( - sm.get_edge_data("b_lag0", "a_lag0") or sm.get_edge_data("a_lag0", "b_lag0") + sm.get_edge_data(DynamicStructureNode('b', 0), DynamicStructureNode('a', 0)) or + sm.get_edge_data(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)) )["weight"] assert 0.99 < edge <= 1.01 @@ -590,7 +596,8 @@ def test_inverse_relationships_get_negative_weight(self): ) sm = from_pandas_dynamic(data, p=1, w_threshold=0.1) edge = ( - sm.get_edge_data("b_lag0", "a_lag0") or sm.get_edge_data("a_lag0", "b_lag0") + sm.get_edge_data(DynamicStructureNode('b', 0), DynamicStructureNode('a', 0)) or + sm.get_edge_data(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)) )["weight"] assert -1.01 < edge <= -0.99 diff --git a/tests/test_dynamicbayesiannetwork.py b/tests/test_dynamicbayesiannetwork.py deleted file mode 100644 index a4ac757..0000000 --- a/tests/test_dynamicbayesiannetwork.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2019-2020 QuantumBlack Visual Analytics Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND -# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS -# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN -# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN -# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# -# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo -# (either separately or in combination, "QuantumBlack Trademarks") are -# trademarks of QuantumBlack. The License does not grant you any right or -# license to the QuantumBlack Trademarks. You may not use the QuantumBlack -# Trademarks or any confusingly similar mark as a trademark for your product, -# or use the QuantumBlack Trademarks in any other manner that might cause -# confusion in the marketplace, including but not limited to in advertising, -# on websites, or on software. -# -# See the License for the specific language governing permissions and -# limitations under the License. 
- -#from causalnex.structure.dynotears import from_numpy_dynamic, from_pandas_dynamic -#from causalnex.network import DynamicBayesianNetwork - -''' -functions to test in DBN are fit_node_states, fit_node_states_and_cpds, fit_latent_cpds, predict, predict_probability -only change to dynotears is using DSM instead of SM -in DSM, functionality is still provided by nx.DiGraph -main update will be to update DBN to use different model than pgmpy.models.BayesianModel -just run regression tests on test_dynotears and test_dynamicstructure_model -''' \ No newline at end of file From 326f745ce3092f389856ee14ee0cd54acae4f44d Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Mon, 17 Oct 2022 14:37:56 -0500 Subject: [PATCH 07/13] fix linting errors --- causalnex/network/network.py | 4 +- causalnex/structure/__init__.py | 4 +- causalnex/structure/dynotears.py | 11 +- causalnex/structure/structuremodel.py | 147 ++- tests/structure/test_dynamicstructuremodel.py | 966 ++++++++++++++---- tests/structure/test_dynotears.py | 30 +- 6 files changed, 903 insertions(+), 259 deletions(-) diff --git a/causalnex/network/network.py b/causalnex/network/network.py index 11ba824..908e55a 100644 --- a/causalnex/network/network.py +++ b/causalnex/network/network.py @@ -43,7 +43,7 @@ from pgmpy.models import BayesianModel from causalnex.estimator.em import EMSingleLatentVariable -from causalnex.structure import StructureModel, DynamicStructureModel +from causalnex.structure import StructureModel from causalnex.utils.pgmpy_utils import pd_to_tabular_cpd @@ -735,4 +735,4 @@ def _predict_probability_from_incomplete_data( probability = probability[cols] probability.columns = cols - return probability \ No newline at end of file + return probability diff --git a/causalnex/structure/__init__.py b/causalnex/structure/__init__.py index 190b745..f1e8cf4 100644 --- a/causalnex/structure/__init__.py +++ b/causalnex/structure/__init__.py @@ -38,8 +38,8 @@ "DAGRegressor", "DAGClassifier", "DynamicStructureModel", 
- "DynamicStructureNode" + "DynamicStructureNode", ] from .pytorch.sklearn import DAGClassifier, DAGRegressor -from .structuremodel import StructureModel, DynamicStructureModel, DynamicStructureNode \ No newline at end of file +from .structuremodel import DynamicStructureModel, DynamicStructureNode, StructureModel diff --git a/causalnex/structure/dynotears.py b/causalnex/structure/dynotears.py index 31f2ac6..c771e6b 100644 --- a/causalnex/structure/dynotears.py +++ b/causalnex/structure/dynotears.py @@ -31,15 +31,14 @@ """ import warnings -from typing import Dict, List, Tuple, Union +from typing import List, Tuple, Union import numpy as np import pandas as pd import scipy.linalg as slin import scipy.optimize as sopt -from causalnex.structure import DynamicStructureModel -from causalnex.structure import DynamicStructureNode +from causalnex.structure import DynamicStructureModel, DynamicStructureNode from causalnex.structure.transformers import DynamicDataTransformer @@ -125,7 +124,11 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments sm = DynamicStructureModel() sm.add_nodes( - [DynamicStructureNode(var, l_val) for var in col_idx.keys() for l_val in range(p + 1)] + [ + DynamicStructureNode(var, l_val) + for var in col_idx.keys() + for l_val in range(p + 1) + ] ) sm.add_weighted_edges_from( [ diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index 350b626..129d7e5 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -31,15 +31,13 @@ ``StructureModel`` is a class that describes relationships between variables as a graph. 
""" -from distutils.ccompiler import new_compiler -from typing import Any, Hashable, List, Set, Tuple, Union, NamedTuple -from collections import Iterable +import types +from typing import Any, Hashable, List, NamedTuple, Set, Tuple, Union import networkx as nx import numpy as np from networkx.exception import NodeNotFound -import inspect -import types + def _validate_origin(origin: str) -> None: """ @@ -340,24 +338,27 @@ def get_markov_blanket( ) return blanket + class DynamicStructureNode(NamedTuple): """ Used by DynamicStructureModel to store each node as a (node_name, lag) pair """ + node: Union[int, str] time_step: int def get_node_name(self): - return f'{self.node}_lag{self.time_step}' + return f"{self.node}_lag{self.time_step}" def check_collection_type(c): return isinstance(c, (list, set, types.GeneratorType)) + def coerce_dsm_edges(arg): """ Used by DynamicStructureModel to convert edges as passed as primitive tuples to tuples of ``DynamicStructureNode``s. - An example input is [((0,0), (1,0), .5), ((1,0), (2,0), .5)]. This would be converted to + An example input is [((0,0), (1,0), .5), ((1,0), (2,0), .5)]. 
This would be converted to [(DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .5)] """ multi_edge = check_collection_type(arg) @@ -365,56 +366,105 @@ def coerce_dsm_edges(arg): if isinstance(arg, types.GeneratorType): arg = list(arg) if not all(isinstance(e, Tuple) for e in arg): - raise TypeError(f'Edges must be tuples containing 2 or 3 elements, received {arg}') - if all(isinstance(e[0], DynamicStructureNode) and isinstance(e[1], DynamicStructureNode) for e in arg): + raise TypeError( + f"Edges must be tuples containing 2 or 3 elements, received {arg}" + ) + if all( + isinstance(e[0], DynamicStructureNode) + and isinstance(e[1], DynamicStructureNode) + for e in arg + ): return arg else: new_arg = [] for e in arg: if not isinstance(e[0], Tuple) or not isinstance(e[1], Tuple): - raise TypeError(f'Nodes in {e} must be tuples with node name and time step') - elif isinstance(e[0], DynamicStructureNode) and isinstance(e[1], DynamicStructureNode): + raise TypeError( + f"Nodes in {e} must be tuples with node name and time step" + ) + elif isinstance(e[0], DynamicStructureNode) and isinstance( + e[1], DynamicStructureNode + ): new_arg.append(e) elif len(e) == 2: - if not isinstance(e[0], DynamicStructureNode) and not isinstance(e[1], DynamicStructureNode): - new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), DynamicStructureNode(e[1][0], e[1][1]))) + if not isinstance(e[0], DynamicStructureNode) and not isinstance( + e[1], DynamicStructureNode + ): + new_arg.append( + ( + DynamicStructureNode(e[0][0], e[0][1]), + DynamicStructureNode(e[1][0], e[1][1]), + ) + ) elif not isinstance(e[0], DynamicStructureNode): new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), e[1])) elif not isinstance(e[1], DynamicStructureNode): new_arg.append((e[0], DynamicStructureNode(e[1][0], e[1][1]))) elif len(e) == 3: - if not isinstance(e[0], DynamicStructureNode) and not isinstance(e[1], DynamicStructureNode): - 
new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), DynamicStructureNode(e[1][0], e[1][1]), e[2])) + if not isinstance(e[0], DynamicStructureNode) and not isinstance( + e[1], DynamicStructureNode + ): + new_arg.append( + ( + DynamicStructureNode(e[0][0], e[0][1]), + DynamicStructureNode(e[1][0], e[1][1]), + e[2], + ) + ) elif not isinstance(e[0], DynamicStructureNode): - new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), e[1], e[2])) + new_arg.append( + (DynamicStructureNode(e[0][0], e[0][1]), e[1], e[2]) + ) elif not isinstance(e[1], DynamicStructureNode): - new_arg.append((e[0], DynamicStructureNode(e[1][0], e[1][1]), e[2])) + new_arg.append( + (e[0], DynamicStructureNode(e[1][0], e[1][1]), e[2]) + ) else: - raise TypeError(f'Argument {e} must be a tuple containing 2 or 3 elements') + raise TypeError( + f"Argument {e} must be a tuple containing 2 or 3 elements" + ) return new_arg else: # if not isinstance(arg, Tuple): # raise TypeError(f'Edges must be tuples containing 2 or 3 elements, received {arg}') if not isinstance(arg[0], Tuple) or not isinstance(arg[1], Tuple): - raise TypeError(f'Nodes in {arg} must be tuples with node name and time step') - elif isinstance(arg[0], DynamicStructureNode) and isinstance(arg[1], DynamicStructureNode): - return arg + raise TypeError( + f"Nodes in {arg} must be tuples with node name and time step" + ) + elif isinstance(arg[0], DynamicStructureNode) and isinstance( + arg[1], DynamicStructureNode + ): + return arg elif len(arg) == 2: - if not isinstance(arg[0], DynamicStructureNode) and not isinstance(arg[1], DynamicStructureNode): - return (DynamicStructureNode(arg[0][0], arg[0][1]), DynamicStructureNode(arg[1][0], arg[1][1])) + if not isinstance(arg[0], DynamicStructureNode) and not isinstance( + arg[1], DynamicStructureNode + ): + return ( + DynamicStructureNode(arg[0][0], arg[0][1]), + DynamicStructureNode(arg[1][0], arg[1][1]), + ) elif not isinstance(arg[0], DynamicStructureNode): return 
(DynamicStructureNode(arg[0][0], arg[0][1]), arg[1]) elif not isinstance(arg[1], DynamicStructureNode): return (arg[0], DynamicStructureNode(arg[1][0], arg[1][1])) elif len(arg) == 3: - if not isinstance(arg[0], DynamicStructureNode) and not isinstance(arg[1], DynamicStructureNode): - return (DynamicStructureNode(arg[0][0], arg[0][1]), DynamicStructureNode(arg[1][0], arg[1][1]), arg[2]) + if not isinstance(arg[0], DynamicStructureNode) and not isinstance( + arg[1], DynamicStructureNode + ): + return ( + DynamicStructureNode(arg[0][0], arg[0][1]), + DynamicStructureNode(arg[1][0], arg[1][1]), + arg[2], + ) elif not isinstance(arg[0], DynamicStructureNode): return (DynamicStructureNode(arg[0][0], arg[0][1]), arg[1], arg[2]) elif not isinstance(arg[1], DynamicStructureNode): return (arg[0], DynamicStructureNode(arg[1][0], arg[1][1]), arg[2]) else: - raise TypeError(f'Argument {arg} must be either a DynamicStructureNode or tuple containing 2 or 3 elements') + raise TypeError( + f"Argument {arg} must be either a DynamicStructureNode or tuple containing 2 or 3 elements" + ) + def coerce_dsm_nodes(arg): """ @@ -434,7 +484,9 @@ def coerce_dsm_nodes(arg): elif isinstance(n, Tuple) and len(n) == 2: new_arg.append(DynamicStructureNode(n[0], n[1])) else: - raise TypeError(f'Argument {n} must be either a DynamicStructureNode or tuple containing 2 elements') + raise TypeError( + f"Argument {n} must be either a DynamicStructureNode or tuple containing 2 elements" + ) return new_arg else: if isinstance(arg, DynamicStructureNode): @@ -442,7 +494,10 @@ def coerce_dsm_nodes(arg): elif isinstance(arg, Tuple) and len(arg) == 2: return DynamicStructureNode(arg[0], arg[1]) else: - raise TypeError(f'Argument {arg} must be either a DynamicStructureNode or tuple containing 2 elements') + raise TypeError( + f"Argument {arg} must be either a DynamicStructureNode or tuple containing 2 elements" + ) + class DynamicStructureModel(StructureModel): """ @@ -482,11 +537,10 @@ def __init__(self, 
incoming_graph_data=None, origin="unknown", **attr): """ super().__init__(incoming_graph_data, origin, **attr) - def add_node(self, dnode: DynamicStructureNode): dnode = coerce_dsm_nodes(dnode) super().add_node(dnode) - + def add_nodes(self, dnodes: List[DynamicStructureNode]): dnodes = coerce_dsm_nodes(dnodes) super().add_nodes_from(dnodes) @@ -498,7 +552,9 @@ def to_directed_class(self): """ return DynamicStructureModel - def get_target_subgraph(self, node: DynamicStructureNode) -> "DynamicStructureModel": + def get_target_subgraph( + self, node: DynamicStructureNode + ) -> "DynamicStructureModel": """ Get the subgraph with the specified node. @@ -515,7 +571,10 @@ def get_target_subgraph(self, node: DynamicStructureNode) -> "DynamicStructureMo return super().get_target_subgraph(node) def get_markov_blanket( - self, nodes: Union[DynamicStructureNode, List[DynamicStructureNode], Set[DynamicStructureNode]] + self, + nodes: Union[ + DynamicStructureNode, List[DynamicStructureNode], Set[DynamicStructureNode] + ], ) -> "DynamicStructureModel": """ Get Markov blanket of specified target nodes @@ -532,7 +591,13 @@ def get_markov_blanket( nodes = coerce_dsm_nodes(nodes) return super().get_markov_blanket(nodes, DynamicStructureModel) - def add_edge(self, u: DynamicStructureNode, v: DynamicStructureNode, origin: str = "unknown", **attr): + def add_edge( + self, + u: DynamicStructureNode, + v: DynamicStructureNode, + origin: str = "unknown", + **attr, + ): edge = coerce_dsm_edges((u, v)) super().add_edge(edge[0], edge[1], origin, **attr) @@ -542,7 +607,10 @@ def add_edge(self, u: DynamicStructureNode, v: DynamicStructureNode, origin: str # integrate seamlessly, where edges will be given origin="unknown" where not provided def add_edges_from( self, - ebunch_to_add: Union[Set[Tuple[DynamicStructureNode, DynamicStructureNode]], List[Tuple[DynamicStructureNode, DynamicStructureNode]]], + ebunch_to_add: Union[ + Set[Tuple[DynamicStructureNode, DynamicStructureNode]], + 
List[Tuple[DynamicStructureNode, DynamicStructureNode]], + ], origin: str = "unknown", **attr, ): @@ -579,10 +647,13 @@ def add_edges_from( # integrate seamlessly, where edges will be given origin="unknown" where not provided def add_weighted_edges_from( self, - ebunch_to_add: Union[Set[Tuple[DynamicStructureNode, DynamicStructureNode, float]], List[Tuple[DynamicStructureNode, DynamicStructureNode, float]]], + ebunch_to_add: Union[ + Set[Tuple[DynamicStructureNode, DynamicStructureNode, float]], + List[Tuple[DynamicStructureNode, DynamicStructureNode, float]], + ], weight: str = "weight", origin: str = "unknown", - **attr + **attr, ): """ Adds a bunch of weighted causal relationships, u -> v. @@ -612,4 +683,6 @@ def add_weighted_edges_from( ebunch_to_add = coerce_dsm_edges(ebunch_to_add) if not isinstance(ebunch_to_add, list): ebunch_to_add = [ebunch_to_add] - super().add_weighted_edges_from(ebunch_to_add, weight=weight, origin=origin, **attr) \ No newline at end of file + super().add_weighted_edges_from( + ebunch_to_add, weight=weight, origin=origin, **attr + ) diff --git a/tests/structure/test_dynamicstructuremodel.py b/tests/structure/test_dynamicstructuremodel.py index 8cf3f84..ed4cf2b 100644 --- a/tests/structure/test_dynamicstructuremodel.py +++ b/tests/structure/test_dynamicstructuremodel.py @@ -26,11 +26,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re + import pytest from networkx.exception import NodeNotFound from causalnex.structure import DynamicStructureModel, DynamicStructureNode -import re + class TestDynamicStructureModel: def test_init_has_origin(self): @@ -55,7 +57,9 @@ def test_edge_unknown_property(self): sm.add_edge((1, 0), (3, 0), origin="learned") sm.add_edge((1, 0), (4, 0), origin="expert") - assert sm.edges_with_origin("unknown") == [(DynamicStructureNode(1, 0), DynamicStructureNode(2, 0))] + assert sm.edges_with_origin("unknown") == [ + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) + ] def test_edge_learned_property(self): """should return only edges whose origin is unknown""" @@ -65,7 +69,9 @@ def test_edge_learned_property(self): sm.add_edge((1, 0), (3, 0), origin="learned") sm.add_edge((1, 0), (4, 0), origin="expert") - assert sm.edges_with_origin("learned") == [(DynamicStructureNode(1, 0), DynamicStructureNode(3, 0))] + assert sm.edges_with_origin("learned") == [ + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)) + ] def test_edge_expert_property(self): """should return only edges whose origin is unknown""" @@ -75,19 +81,31 @@ def test_edge_expert_property(self): sm.add_edge((1, 0), (3, 0), origin="learned") sm.add_edge((1, 0), (4, 0), origin="expert") - assert sm.edges_with_origin("expert") == [(DynamicStructureNode(1, 0), DynamicStructureNode(4, 0))] + assert sm.edges_with_origin("expert") == [ + (DynamicStructureNode(1, 0), DynamicStructureNode(4, 0)) + ] def test_to_directed(self): """should create a structure model""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + ] - edges = [(nodes[0], nodes[1]), (nodes[1], nodes[0]), (nodes[1], nodes[2]), (nodes[2], nodes[3])] + edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[0]), 
+ (nodes[1], nodes[2]), + (nodes[2], nodes[3]), + ] sm.add_edges_from(edges) - + dag = sm.to_directed() - + assert isinstance(dag, DynamicStructureModel) assert all((edge[0], edge[1]) in dag.edges for edge in edges) @@ -95,14 +113,27 @@ def test_to_undirected(self): """should create an undirected Graph""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + ] - edges = [(nodes[0], nodes[1]), (nodes[1], nodes[0]), (nodes[1], nodes[2]), (nodes[2], nodes[3])] + edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[0]), + (nodes[1], nodes[2]), + (nodes[2], nodes[3]), + ] sm.add_edges_from(edges) udg = sm.to_undirected() - assert all((edge[0], edge[1]) in udg.edges for edge in [(nodes[1], nodes[2]), (nodes[2], nodes[3])]) + assert all( + (edge[0], edge[1]) in udg.edges + for edge in [(nodes[1], nodes[2]), (nodes[2], nodes[3])] + ) assert (nodes[0], nodes[1]) in udg.edges or (nodes[1], nodes[0]) in udg.edges assert len(udg.edges) == 3 @@ -115,7 +146,11 @@ def test_add_edge_default(self): sm.add_edge((1, 0), (2, 0)) assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges - assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "unknown") in sm.edges.data("origin") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + "unknown", + ) in sm.edges.data("origin") def test_add_edge_unknown(self): """edges added with unknown origin should be labelled as unknown origin""" @@ -124,7 +159,11 @@ def test_add_edge_unknown(self): sm.add_edge((1, 0), (2, 0), "unknown") assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges - assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "unknown") in sm.edges.data("origin") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 
0), + "unknown", + ) in sm.edges.data("origin") def test_add_edge_learned(self): """edges added with learned origin should be labelled as learned origin""" @@ -133,7 +172,11 @@ def test_add_edge_learned(self): sm.add_edge((1, 0), (2, 0), "learned") assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges - assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "learned") in sm.edges.data("origin") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + "learned", + ) in sm.edges.data("origin") def test_add_edge_expert(self): """edges added with expert origin should be labelled as expert origin""" @@ -142,7 +185,11 @@ def test_add_edge_expert(self): sm.add_edge((1, 0), (2, 0), "expert") assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges - assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "expert") in sm.edges.data("origin") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + "expert", + ) in sm.edges.data("origin") def test_add_edge_other(self): """edges added with other origin should throw an error""" @@ -159,16 +206,28 @@ def test_add_edge_custom_attr(self): sm.add_edge((1, 0), (2, 0), x="Y") assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)) in sm.edges - assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "Y") in sm.edges.data("x") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + "Y", + ) in sm.edges.data("x") def test_add_edge_multiple_times(self): """adding an edge again should update the edges origin attr""" sm = DynamicStructureModel() sm.add_edge((1, 0), (2, 0), origin="unknown") - assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "unknown") in sm.edges.data("origin") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + "unknown", + ) in sm.edges.data("origin") sm.add_edge((1, 0), (2, 0), origin="learned") - assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 
"learned") in sm.edges.data("origin") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + "learned", + ) in sm.edges.data("origin") def test_add_multiple_edges(self): """it should be possible to add multiple edges with different origins""" @@ -178,17 +237,33 @@ def test_add_multiple_edges(self): sm.add_edge((1, 0), (3, 0), origin="learned") sm.add_edge((1, 0), (4, 0), origin="expert") - assert (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), "unknown") in sm.edges.data("origin") - assert (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0), "learned") in sm.edges.data("origin") - assert (DynamicStructureNode(1, 0), DynamicStructureNode(4, 0), "expert") in sm.edges.data("origin") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + "unknown", + ) in sm.edges.data("origin") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(3, 0), + "learned", + ) in sm.edges.data("origin") + assert ( + DynamicStructureNode(1, 0), + DynamicStructureNode(4, 0), + "expert", + ) in sm.edges.data("origin") class TestDynamicStructureModelAddEdgesFrom: def test_add_edges_from_default(self): """edges added with default origin should be identified as unknown origin""" - print('******************* hello **************************') + print("******************* hello **************************") sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges) @@ -199,7 +274,11 @@ def test_add_edges_from_unknown(self): """edges added with unknown origin should be labelled as unknown origin""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + 
DynamicStructureNode(3, 0), + ] edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, "unknown") @@ -211,7 +290,11 @@ def test_add_edges_from_learned(self): """edges added with learned origin should be labelled as learned origin""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, "learned") @@ -223,7 +306,11 @@ def test_add_edges_from_expert(self): """edges added with expert origin should be labelled as expert origin""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, "expert") @@ -244,7 +331,11 @@ def test_add_edges_from_custom_attr(self): """it should be possible to add edges with custom attributes""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, x="Y") @@ -256,7 +347,11 @@ def test_add_edges_from_multiple_times(self): """adding edges again should update the edges origin attr""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] sm.add_edges_from(edges, "unknown") @@ -268,7 +363,12 @@ def test_add_multiple_edges(self): """it should be possible to 
add multiple edges with different origins""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + ] sm.add_edges_from([(nodes[0], nodes[1])], origin="unknown") sm.add_edges_from([(nodes[0], nodes[2])], origin="learned") sm.add_edges_from([(nodes[0], nodes[3])], origin="expert") @@ -282,9 +382,13 @@ class TestDynamicStructureModelAddWeightedEdgesFrom: def test_add_weighted_edges_from_default(self): """edges added with default origin should be identified as unknown origin""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] - edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + edges = [(nodes[0], nodes[1], 0.5), (nodes[1], nodes[2], 0.5)] sm.add_weighted_edges_from(edges) assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) @@ -294,9 +398,13 @@ def test_add_weighted_edges_from_unknown(self): """edges added with unknown origin should be labelled as unknown origin""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] - edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + edges = [(nodes[0], nodes[1], 0.5), (nodes[1], nodes[2], 0.5)] sm.add_weighted_edges_from(edges, origin="unknown") assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) @@ -306,9 +414,13 @@ def test_add_weighted_edges_from_learned(self): """edges added with learned origin should be labelled as learned origin""" sm = DynamicStructureModel() - nodes = 
[DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] - - edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] + + edges = [(nodes[0], nodes[1], 0.5), (nodes[1], nodes[2], 0.5)] sm.add_weighted_edges_from(edges, origin="learned") assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) @@ -318,10 +430,14 @@ def test_add_weighted_edges_from_expert(self): """edges added with expert origin should be labelled as expert origin""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] - #edges = [(1, 2, 0.5), (2, 3, 0.5)] - edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + # edges = [(1, 2, 0.5), (2, 3, 0.5)] + edges = [(nodes[0], nodes[1], 0.5), (nodes[1], nodes[2], 0.5)] sm.add_weighted_edges_from(edges, origin="expert") assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) @@ -332,7 +448,7 @@ def test_add_weighted_edges_from_other(self): sm = DynamicStructureModel() nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)] - + with pytest.raises(ValueError, match="^Unknown origin: must be one of.*$"): sm.add_weighted_edges_from([(nodes[0], nodes[1], 0.5)], origin="other") @@ -340,10 +456,14 @@ def test_add_weighted_edges_from_custom_attr(self): """it should be possible to add edges with custom attributes""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] - - #edges = [(1, 2, 0.5), (2, 3, 0.5)] - edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] + + # edges = [(1, 2, 0.5), (2, 3, 0.5)] + edges = [(nodes[0], nodes[1], 
0.5), (nodes[1], nodes[2], 0.5)] sm.add_weighted_edges_from(edges, x="Y") assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges) @@ -353,14 +473,18 @@ def test_add_weighted_edges_from_multiple_times(self): """adding edges again should update the edges origin attr""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + ] + + # edges = [(1, 2, 0.5), (2, 3, 0.5)] + edges = [(nodes[0], nodes[1], 0.5), (nodes[1], nodes[2], 0.5)] - #edges = [(1, 2, 0.5), (2, 3, 0.5)] - edges = [(nodes[0], nodes[1], .5), (nodes[1], nodes[2], .5)] - sm.add_weighted_edges_from(edges, origin="unknown") assert all((u, v, "unknown") in sm.edges.data("origin") for u, v, _ in edges) - + sm.add_weighted_edges_from(edges, origin="learned") assert all((u, v, "learned") in sm.edges.data("origin") for u, v, _ in edges) @@ -368,7 +492,12 @@ def test_add_multiple_weighted_edges(self): """it should be possible to add multiple edges with different origins""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + ] sm.add_weighted_edges_from([(nodes[0], nodes[1], 0.5)], origin="unknown") sm.add_weighted_edges_from([(nodes[0], nodes[2], 0.5)], origin="learned") sm.add_weighted_edges_from([(nodes[0], nodes[3], 0.5)], origin="expert") @@ -383,37 +512,75 @@ def test_remove_edges_below_threshold(self): """Edges whose weight is less than a defined threshold should be removed""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)] - - strong_edges = [(nodes[0], nodes[1], 
1.0), (nodes[0], nodes[2], 0.8), (nodes[0], nodes[4], 2.0)] - weak_edges = [(nodes[0], nodes[3], 0.4), (nodes[1], nodes[2], 0.6), (nodes[2], nodes[4], 0.5)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(5, 0), + ] + + strong_edges = [ + (nodes[0], nodes[1], 1.0), + (nodes[0], nodes[2], 0.8), + (nodes[0], nodes[4], 2.0), + ] + weak_edges = [ + (nodes[0], nodes[3], 0.4), + (nodes[1], nodes[2], 0.6), + (nodes[2], nodes[4], 0.5), + ] sm.add_weighted_edges_from(strong_edges) sm.add_weighted_edges_from(weak_edges) sm.remove_edges_below_threshold(0.7) - assert set(sm.edges(data="weight")) == set((u, v, w) for u, v, w in strong_edges) + assert set(sm.edges(data="weight")) == set( + (u, v, w) for u, v, w in strong_edges + ) def test_negative_weights(self): """Negative edges whose absolute value is greater than the defined threshold should not be removed""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)] - - strong_edges = [(nodes[0], nodes[1], -3.0), (nodes[2], nodes[0], 0.7), (nodes[0], nodes[4], -2.0)] - weak_edges = [(nodes[0], nodes[3], 0.4), (nodes[1], nodes[2], -0.6), (nodes[2], nodes[4], -0.5)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(5, 0), + ] + + strong_edges = [ + (nodes[0], nodes[1], -3.0), + (nodes[2], nodes[0], 0.7), + (nodes[0], nodes[4], -2.0), + ] + weak_edges = [ + (nodes[0], nodes[3], 0.4), + (nodes[1], nodes[2], -0.6), + (nodes[2], nodes[4], -0.5), + ] sm.add_weighted_edges_from(strong_edges) sm.add_weighted_edges_from(weak_edges) sm.remove_edges_below_threshold(0.7) - assert set(sm.edges(data="weight")) == set((u, v, w) for u, v, w in strong_edges) + assert set(sm.edges(data="weight")) == set( + (u, v, w) 
for u, v, w in strong_edges + ) def test_equal_weights(self): """Edges whose absolute value is equal to the defined threshold should not be removed""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(5, 0), + ] strong_edges = [(nodes[0], nodes[1], 1.0), (nodes[0], nodes[4], 2.0)] equal_edges = [(nodes[0], nodes[2], 0.6), (nodes[1], nodes[2], 0.6)] @@ -425,15 +592,19 @@ def test_equal_weights(self): sm.remove_edges_below_threshold(0.6) assert set(sm.edges(data="weight")) == set.union( - set((u, v, w) for u, v, w in strong_edges), - set((u, v, w) for u, v, w in equal_edges) + set((u, v, w) for u, v, w in strong_edges), + set((u, v, w) for u, v, w in equal_edges), ) def test_graph_with_no_edges(self): - """Can still run even if the graph is without edges""" + """Can still run even if the graph is without edges""" sm = DynamicStructureModel() # (var, lag) - all nodes here are in current timestep - nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)] + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + ] sm.add_nodes(nodes) sm.remove_edges_below_threshold(0.6) @@ -446,11 +617,32 @@ class TestDynamicStructureModelGetLargestSubgraph: @pytest.mark.parametrize( "test_input, expected", [ - ([(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0))]), - 
([(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0))]), - - #([(0, 1), (1, 2), (1, 3), (4, 6)], [(0, 1), (1, 2), (1, 3)]), - #([(3, 4), (3, 5), (7, 6)], [(3, 4), (3, 5)]), + ( + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + ], + ), + ( + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + ], + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + ], + ), + # ([(0, 1), (1, 2), (1, 3), (4, 6)], [(0, 1), (1, 2), (1, 3)]), + # ([(3, 4), (3, 5), (7, 6)], [(3, 4), (3, 5)]), ], ) def test_get_largest_subgraph(self, test_input, expected): @@ -469,8 +661,20 @@ def test_get_largest_subgraph(self, test_input, expected): def test_more_than_one_largest(self): """Return the first largest when there are more than one largest subgraph""" - nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)] - edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[3], nodes[4]), (nodes[3], nodes[5])] + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 
0), + DynamicStructureNode(5, 0), + ] + edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[3], nodes[4]), + (nodes[3], nodes[5]), + ] sm = DynamicStructureModel() sm.add_edges_from(edges) largest_subgraph = sm.get_largest_subgraph() @@ -492,23 +696,48 @@ def test_isolates(self): """Should return None if the structure model only contains isolates""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), - DynamicStructureNode(2, 0), DynamicStructureNode(7, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(7, 0), + ] sm.add_nodes(nodes) assert sm.get_largest_subgraph() is None def test_isolates_nodes_and_edges(self): """Should be able to return the largest subgraph""" - nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0), DynamicStructureNode(7, 0), DynamicStructureNode(8, 0), DynamicStructureNode(9, 0)] - edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[5], nodes[6])] + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(6, 0), + DynamicStructureNode(7, 0), + DynamicStructureNode(8, 0), + DynamicStructureNode(9, 0), + ] + edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[5], nodes[6]), + ] isolated_nodes = [nodes[7], nodes[8], nodes[9]] sm = DynamicStructureModel() sm.add_edges_from(edges) sm.add_nodes(isolated_nodes) largest_subgraph = sm.get_largest_subgraph() - expected_edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3])] + expected_edges = [ + (nodes[0], nodes[1]), + 
(nodes[1], nodes[2]), + (nodes[1], nodes[3]), + ] expected_graph = DynamicStructureModel() expected_graph.add_edges_from(expected_edges) @@ -517,7 +746,13 @@ def test_isolates_nodes_and_edges(self): def test_different_origins_and_weights(self): """The largest subgraph returned should still have the edge data preserved from the original graph""" - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(6, 0), + ] sm = DynamicStructureModel() sm.add_weighted_edges_from([(nodes[0], nodes[1], 2.0)], origin="unknown") sm.add_weighted_edges_from([(nodes[0], nodes[2], 1.0)], origin="learned") @@ -529,16 +764,56 @@ def test_different_origins_and_weights(self): (nodes[0], nodes[1], "unknown"), (nodes[0], nodes[2], "learned"), } - assert set(largest_subgraph.edges.data("weight")) == {(nodes[0], nodes[1], 2.0), (nodes[0], nodes[2], 1.0)} + assert set(largest_subgraph.edges.data("weight")) == { + (nodes[0], nodes[1], 2.0), + (nodes[0], nodes[2], 1.0), + } class TestDynamicStructureModelGetTargetSubgraph: @pytest.mark.parametrize( "target_node, test_input, expected", [ - (DynamicStructureNode(1, 0), [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0))]), - (DynamicStructureNode(3, 0), [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))], [(DynamicStructureNode(3, 0), 
DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0))]), - (DynamicStructureNode(7, 0), [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(1, 0))], [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))]), + ( + DynamicStructureNode(1, 0), + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + ], + ), + ( + DynamicStructureNode(3, 0), + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + ], + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + ], + ), + ( + DynamicStructureNode(7, 0), + [ + (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(5, 0), DynamicStructureNode(1, 0)), + ], + [ + (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + ], + ), ], ) def test_get_target_subgraph(self, target_node, test_input, expected): @@ -557,14 +832,32 @@ def test_get_target_subgraph(self, target_node, test_input, 
expected): "target_node, test_input, expected", [ ( - DynamicStructureNode('a', 0), - [(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), (DynamicStructureNode('c', 0), DynamicStructureNode('d', 0)), (DynamicStructureNode('e', 0), DynamicStructureNode('f', 0))], - [(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), (DynamicStructureNode('c', 0), DynamicStructureNode('d', 0))], + DynamicStructureNode("a", 0), + [ + (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), + (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), + (DynamicStructureNode("e", 0), DynamicStructureNode("f", 0)), + ], + [ + (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), + (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), + ], ), ( - DynamicStructureNode('g', 0), - [(DynamicStructureNode('g', 0), DynamicStructureNode('h', 0)), (DynamicStructureNode('g', 0), DynamicStructureNode('z', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), (DynamicStructureNode('c', 0), DynamicStructureNode('d', 0))], - [(DynamicStructureNode('g', 0), DynamicStructureNode('h', 0)), (DynamicStructureNode('g', 0), DynamicStructureNode('z', 0))], + DynamicStructureNode("g", 0), + [ + (DynamicStructureNode("g", 0), DynamicStructureNode("h", 0)), + (DynamicStructureNode("g", 0), DynamicStructureNode("z", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), + (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), + ], + [ + (DynamicStructureNode("g", 0), DynamicStructureNode("h", 0)), + (DynamicStructureNode("g", 0), DynamicStructureNode("z", 0)), + 
], ), ], ) @@ -584,13 +877,22 @@ def test_get_subgraph_string(self, target_node, test_input, expected): "target_node, test_input", [ ( - DynamicStructureNode(7, 0), - [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))] - ), + DynamicStructureNode(7, 0), + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), + ], + ), ( - DynamicStructureNode(1, 0), - [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))] - ) + DynamicStructureNode(1, 0), + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + ], + ), ], ) def test_node_not_in_graph(self, target_node, test_input): @@ -608,23 +910,45 @@ def test_node_not_in_graph(self, target_node, test_input): def test_isolates(self): """Should return an isolated node""" - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), - DynamicStructureNode(2, 0), DynamicStructureNode(7, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(7, 0), + ] sm = DynamicStructureModel() sm.add_nodes(nodes) subgraph = sm.get_target_subgraph(DynamicStructureNode(1, 0)) expected_graph = DynamicStructureModel() expected_graph.add_node(DynamicStructureNode(1, 0)) - print(f'subgraph nodes {subgraph.nodes}\n') - print(f'expected nodes {expected_graph.nodes}') + print(f"subgraph nodes {subgraph.nodes}\n") + 
print(f"expected nodes {expected_graph.nodes}") assert set(subgraph.nodes) == set(expected_graph.nodes) assert set(subgraph.edges) == set(expected_graph.edges) def test_isolates_nodes_and_edges(self): """Should be able to return the subgraph with the specified node""" - nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0), DynamicStructureNode(7, 0), DynamicStructureNode(8, 0), DynamicStructureNode(9, 0)] - edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[5], nodes[6]), (nodes[4], nodes[5])] + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(6, 0), + DynamicStructureNode(7, 0), + DynamicStructureNode(8, 0), + DynamicStructureNode(9, 0), + ] + edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[5], nodes[6]), + (nodes[4], nodes[5]), + ] isolated_nodes = [nodes[7], nodes[8], nodes[9]] sm = DynamicStructureModel() sm.add_edges_from(edges) @@ -641,8 +965,14 @@ def test_different_origins_and_weights(self): """The subgraph returned should still have the edge data preserved from the original graph""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0)] - + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(6, 0), + ] + sm.add_weighted_edges_from([(nodes[0], nodes[1], 2.0)], origin="unknown") sm.add_weighted_edges_from([(nodes[0], nodes[2], 1.0)], origin="learned") sm.add_weighted_edges_from([(nodes[3], nodes[4], 0.7)], origin="expert") @@ -653,14 +983,31 @@ def 
test_different_origins_and_weights(self): (nodes[0], nodes[1], "unknown"), (nodes[0], nodes[2], "learned"), } - assert set(subgraph.edges.data("weight")) == {(nodes[0], nodes[1], 2.0), (nodes[0], nodes[2], 1.0)} + assert set(subgraph.edges.data("weight")) == { + (nodes[0], nodes[1], 2.0), + (nodes[0], nodes[2], 1.0), + } def test_instance_type(self): """The subgraph returned should still be a DynamicStructureModel instance""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)] - sm.add_edges_from([(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[4], nodes[5])]) + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(6, 0), + ] + sm.add_edges_from( + [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[4], nodes[5]), + ] + ) subgraph = sm.get_target_subgraph(nodes[2]) assert isinstance(subgraph, DynamicStructureModel) @@ -669,8 +1016,22 @@ def test_get_target_subgraph_twice(self): """get_target_subgraph should be able to run more than once""" sm = DynamicStructureModel() - nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)] - sm.add_edges_from([(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[4], nodes[5])]) + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(6, 0), + ] + sm.add_edges_from( + [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[4], nodes[5]), + ] + ) subgraph = sm.get_target_subgraph(nodes[0]) 
subgraph.remove_edge(nodes[0], nodes[1]) @@ -687,11 +1048,63 @@ def test_get_target_subgraph_twice(self): class TestDynamicStructureModelGetMarkovBlanket: @pytest.mark.parametrize( "target_node, test_input, expected", - [ - (DynamicStructureNode(1, 0), [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(5, 0))], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0))]), - (DynamicStructureNode(1, 0), [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0))], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0))]), - (DynamicStructureNode(3, 0), [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0))], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0))]), - (DynamicStructureNode(7, 0), [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0))], [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0))]), + [ + ( + 
DynamicStructureNode(1, 0), + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + ], + ), + ( + DynamicStructureNode(1, 0), + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0)), + ], + ), + ( + DynamicStructureNode(3, 0), + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), + ], + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + ], + ), + ( + DynamicStructureNode(7, 0), + [ + (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), + (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0)), + ], + [ + (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), + (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), + (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0)), + ], + ), ], ) def test_get_markov_blanket_single(self, target_node, test_input, expected): @@ -711,15 
+1124,61 @@ def test_get_markov_blanket_single(self, target_node, test_input, expected): [ ( [DynamicStructureNode(1, 0), DynamicStructureNode(4, 0)], - [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(5, 0))], - [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(5, 0))], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(5, 0)), + ], + ), + ( + [DynamicStructureNode(2, 0), DynamicStructureNode(4, 0)], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0)), + ], + [ + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0)), + ], + ), + ( + [DynamicStructureNode(3, 0), DynamicStructureNode(6, 0)], + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), + ], + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), 
DynamicStructureNode(5, 0)), + (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), + ], ), - ([DynamicStructureNode(2, 0), DynamicStructureNode(4, 0)], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0))], [(DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(3, 0))]), - ([DynamicStructureNode(3, 0), DynamicStructureNode(6, 0)], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0))], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0))]), ( [DynamicStructureNode(2, 0), DynamicStructureNode(5, 0)], - [(DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0))], - [(DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0))], + [ + (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(6, 0), DynamicStructureNode(7, 0)), + (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0)), + ], + [ + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(2, 0), 
DynamicStructureNode(3, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), + (DynamicStructureNode(5, 0), DynamicStructureNode(8, 0)), + ], ), ], ) @@ -739,14 +1198,31 @@ def test_get_markov_blanket_multiple(self, target_nodes, test_input, expected): "target_node, test_input, expected", [ ( - DynamicStructureNode('a', 0), - [(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), (DynamicStructureNode('c', 0), DynamicStructureNode('d', 0)), (DynamicStructureNode('e', 0), DynamicStructureNode('f', 0))], - [(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0))], + DynamicStructureNode("a", 0), + [ + (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), + (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), + (DynamicStructureNode("e", 0), DynamicStructureNode("f", 0)), + ], + [ + (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), + ], ), ( - DynamicStructureNode('g', 0), - [(DynamicStructureNode('g', 0), DynamicStructureNode('h', 0)), (DynamicStructureNode('g', 0), DynamicStructureNode('z', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)), (DynamicStructureNode('a', 0), DynamicStructureNode('c', 0)), (DynamicStructureNode('c', 0), DynamicStructureNode('d', 0))], - [(DynamicStructureNode('g', 0), DynamicStructureNode('h', 0)), (DynamicStructureNode('g', 0), DynamicStructureNode('z', 0))], + DynamicStructureNode("g", 0), + [ + (DynamicStructureNode("g", 0), DynamicStructureNode("h", 0)), + (DynamicStructureNode("g", 0), DynamicStructureNode("z", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), + (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), + ], + [ + 
(DynamicStructureNode("g", 0), DynamicStructureNode("h", 0)), + (DynamicStructureNode("g", 0), DynamicStructureNode("z", 0)), + ], ), ], ) @@ -765,10 +1241,23 @@ def test_get_markov_blanket_string(self, target_node, test_input, expected): @pytest.mark.parametrize( "target_node, test_input", [ - (DynamicStructureNode(7, 0), [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))]), - (DynamicStructureNode(1, 0), [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))]), - #([DynamicStructureNode(1, 0), DynamicStructureNode(7, 0)], [(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0))]), - #([DynamicStructureNode(8, 0), DynamicStructureNode(2, 0)], [(DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0))]), + ( + DynamicStructureNode(7, 0), + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), + ], + ), + ( + DynamicStructureNode(1, 0), + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + ], + ) ], ) def test_node_not_in_graph(self, target_node, test_input): @@ -786,8 +1275,13 @@ def test_node_not_in_graph(self, target_node, test_input): def test_isolates(self): """Should 
return an isolated node""" - nodes = [DynamicStructureNode(1, 0), DynamicStructureNode(3, 0), DynamicStructureNode(5, 0), - DynamicStructureNode(2, 0), DynamicStructureNode(7, 0)] + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(7, 0), + ] sm = DynamicStructureModel() sm.add_nodes(nodes) blanket = sm.get_markov_blanket(nodes[0]) @@ -801,10 +1295,25 @@ def test_isolates(self): def test_isolates_nodes_and_edges(self): """Should be able to return the subgraph with the specified node""" - nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), - DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(5, 0), DynamicStructureNode(6, 0), - DynamicStructureNode(7, 0), DynamicStructureNode(8, 0), DynamicStructureNode(9, 0)] - edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[5], nodes[6]), (nodes[4], nodes[5])] + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(6, 0), + DynamicStructureNode(7, 0), + DynamicStructureNode(8, 0), + DynamicStructureNode(9, 0), + ] + edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[5], nodes[6]), + (nodes[4], nodes[5]), + ] isolated_nodes = [nodes[7], nodes[8], nodes[9]] sm = DynamicStructureModel() sm.add_edges_from(edges) @@ -819,32 +1328,50 @@ def test_isolates_nodes_and_edges(self): def test_instance_type(self): """The subgraph returned should still be a DynamicStructureModel instance""" - nodes = [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), - DynamicStructureNode(3, 0), DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)] - sm = DynamicStructureModel() - sm.add_edges_from([(nodes[0], nodes[1]), 
(nodes[1], nodes[2]), (nodes[1], nodes[3]), (nodes[4], nodes[5])]) + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(6, 0), + ] + sm = DynamicStructureModel() + sm.add_edges_from( + [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[4], nodes[5]), + ] + ) subgraph = sm.get_markov_blanket(nodes[2]) assert isinstance(subgraph, DynamicStructureModel) -class TestDynamicStructureModelEdgeCoercion: +class TestDynamicStructureModelEdgeCoercion: def test_edge_not_tuple(self): - edges = [((1, 0), (3, 0), .5), 6] + edges = [((1, 0), (3, 0), 0.5), 6] sm = DynamicStructureModel() - + with pytest.raises( TypeError, - match=re.escape(f"Edges must be tuples containing 2 or 3 elements, received {edges}"), + match=re.escape( + f"Edges must be tuples containing 2 or 3 elements, received {edges}" + ), ): sm.add_edges_from(edges) def test_multi_edge_not_dsn(self): - edges = [((0,0), (1,0)), ((1,0), (2,0))] + edges = [((0, 0), (1, 0)), ((1, 0), (2, 0))] sm = DynamicStructureModel() sm.add_edges_from(edges) - expected_edges = [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), (DynamicStructureNode(1,0), DynamicStructureNode(2,0))] + expected_edges = [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + ] expected_graph = DynamicStructureModel() expected_graph.add_edges_from(expected_edges) @@ -852,11 +1379,14 @@ def test_multi_edge_not_dsn(self): assert set(sm.edges) == set(expected_graph.edges) def test_weighted_multi_edge_not_dsn(self): - edges = [((0,0), (1,0), .5), ((1,0), (2,0), .7)] + edges = [((0, 0), (1, 0), 0.5), ((1, 0), (2, 0), 0.7)] sm = DynamicStructureModel() sm.add_weighted_edges_from(edges) - expected_edges = [(DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .7)] + 
expected_edges = [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 0.7), + ] expected_graph = DynamicStructureModel() expected_graph.add_weighted_edges_from(expected_edges) @@ -867,13 +1397,19 @@ def test_weighted_multi_edge_not_dsn(self): "input_edges, expected_edges", [ ( - [((0,0), (1,0), .5), ((1,0), (2,0), .7)], - [(DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .7)] + [((0, 0), (1, 0), 0.5), ((1, 0), (2, 0), 0.7)], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 0.7), + ], ), ( - [((0,0), (1,0)), ((1,0), (2,0))], - [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), (DynamicStructureNode(1,0), DynamicStructureNode(2,0))] - ) + [((0, 0), (1, 0)), ((1, 0), (2, 0))], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + ], + ), ], ) def test_multi_edge_dsn(self, input_edges, expected_edges): @@ -894,12 +1430,14 @@ def test_multi_edge_dsn(self, input_edges, expected_edges): assert set(sm.edges) == set(expected_graph.edges) def test_node_not_tuple(self): - edges = [((1, 0), (3, 0), .5), ((1, 0), 3, .7)] + edges = [((1, 0), (3, 0), 0.5), ((1, 0), 3, 0.7)] sm = DynamicStructureModel() - + with pytest.raises( TypeError, - match=re.escape(f"Nodes in {edges[1]} must be tuples with node name and time step"), + match=re.escape( + f"Nodes in {edges[1]} must be tuples with node name and time step" + ), ): sm.add_edges_from(edges) @@ -907,17 +1445,35 @@ def test_node_not_tuple(self): "input_edges, expected_edges", [ ( - [(DynamicStructureNode(0,0), (1,0), .5), ((1,0), DynamicStructureNode(2,0), .7)], - [(DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .7)] + [ + (DynamicStructureNode(0, 0), (1, 0), 0.5), + ((1, 0), 
DynamicStructureNode(2, 0), 0.7), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 0.7), + ], ), ( - [(DynamicStructureNode(0,0), (1,0)), ((1,0), DynamicStructureNode(2,0))], - [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), (DynamicStructureNode(1,0), DynamicStructureNode(2,0))] + [ + (DynamicStructureNode(0, 0), (1, 0)), + ((1, 0), DynamicStructureNode(2, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + ], ), ( - [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), ((1,0), DynamicStructureNode(2,0))], - [(DynamicStructureNode(0,0), DynamicStructureNode(1,0)), (DynamicStructureNode(1,0), DynamicStructureNode(2,0))] - ) + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + ((1, 0), DynamicStructureNode(2, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + ], + ), ], ) def test_multi_edge_one_dsn(self, input_edges, expected_edges): @@ -938,18 +1494,20 @@ def test_multi_edge_one_dsn(self, input_edges, expected_edges): assert set(sm.edges) == set(expected_graph.edges) def test_multi_edge_bad_tuple(self): - edges = [((0,0), (1,0), .5), ((1,0), (2,0), .7, .8)] + edges = [((0, 0), (1, 0), 0.5), ((1, 0), (2, 0), 0.7, 0.8)] sm = DynamicStructureModel() with pytest.raises( TypeError, - match=re.escape(f"Argument {edges[1]} must be a tuple containing 2 or 3 elements"), + match=re.escape( + f"Argument {edges[1]} must be a tuple containing 2 or 3 elements" + ), ): sm.add_weighted_edges_from(edges) # def test_single_edge_not_tuple(self): # edge = 6 # sm = DynamicStructureModel() - + # with pytest.raises( # TypeError, # match=re.escape(f"Edges must be tuples containing 2 or 3 elements, received {edge}"), @@ -960,10 +1518,12 @@ def test_single_edge_node_not_tuple(self): u = (1, 0) v = 3 sm = 
DynamicStructureModel() - + with pytest.raises( TypeError, - match=re.escape(f"Nodes in {(u, v)} must be tuples with node name and time step"), + match=re.escape( + f"Nodes in {(u, v)} must be tuples with node name and time step" + ), ): sm.add_edge(u, v) @@ -971,33 +1531,33 @@ def test_single_edge_node_not_tuple(self): "input_edge, expected_edge", [ ( - ((0,0), (1,0), .5), - (DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5) + ((0, 0), (1, 0), 0.5), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), ), ( - ((0,0), (1,0)), - (DynamicStructureNode(0,0), DynamicStructureNode(1,0)) + ((0, 0), (1, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), ), ( - (DynamicStructureNode(0,0), DynamicStructureNode(1,0)), - (DynamicStructureNode(0,0), DynamicStructureNode(1,0)) + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), ), ( - (DynamicStructureNode(0,0), (1,0)), - (DynamicStructureNode(0,0), DynamicStructureNode(1,0)) + (DynamicStructureNode(0, 0), (1, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), ), ( - ((0,0), DynamicStructureNode(1,0)), - (DynamicStructureNode(0,0), DynamicStructureNode(1,0)) + ((0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), ), ( - (DynamicStructureNode(0,0), (1,0), .5), - (DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5) + (DynamicStructureNode(0, 0), (1, 0), 0.5), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), ), ( - ((0,0), DynamicStructureNode(1,0), .5), - (DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5) - ) + ((0, 0), DynamicStructureNode(1, 0), 0.5), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + ), ], ) def test_single_edge_dsn(self, input_edge, expected_edge): @@ -1018,31 +1578,33 @@ def test_single_edge_dsn(self, input_edge, expected_edge): assert set(sm.edges) == set(expected_graph.edges) def 
test_single_edge_bad_tuple(self): - edge = ((1,0), (2,0), .7, .8) + edge = ((1, 0), (2, 0), 0.7, 0.8) sm = DynamicStructureModel() with pytest.raises( TypeError, - match=re.escape(f"Argument {edge} must be either a DynamicStructureNode or tuple containing 2 or 3 elements"), + match=re.escape( + f"Argument {edge} must be either a DynamicStructureNode or tuple containing 2 or 3 elements" + ), ): sm.add_weighted_edges_from(edge) -class TestDynamicStructureModelNodeCoercion: +class TestDynamicStructureModelNodeCoercion: @pytest.mark.parametrize( "input_nodes, expected_nodes", [ ( - [(0,0), (1,0)], - [DynamicStructureNode(0,0), DynamicStructureNode(1,0)] + [(0, 0), (1, 0)], + [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)], ), ( - [DynamicStructureNode(0,0), (1,0)], - [DynamicStructureNode(0,0), DynamicStructureNode(1,0)] + [DynamicStructureNode(0, 0), (1, 0)], + [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)], ), ( (DynamicStructureNode(n, 0) for n in range(2)), - [DynamicStructureNode(0,0), DynamicStructureNode(1,0)] - ) + [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)], + ), ], ) def test_multi_node(self, input_nodes, expected_nodes): @@ -1057,14 +1619,8 @@ def test_multi_node(self, input_nodes, expected_nodes): @pytest.mark.parametrize( "input_node, expected_node", [ - ( - (0,0), - DynamicStructureNode(0,0) - ), - ( - DynamicStructureNode(0,0), - DynamicStructureNode(0,0) - ) + ((0, 0), DynamicStructureNode(0, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(0, 0)), ], ) def test_single_node(self, input_node, expected_node): @@ -1077,19 +1633,23 @@ def test_single_node(self, input_node, expected_node): assert set(sm.nodes) == set(expected_graph.nodes) def test_multi_node_bad_tuple(self): - nodes = [(0,0), (1,0,1)] + nodes = [(0, 0), (1, 0, 1)] sm = DynamicStructureModel() with pytest.raises( TypeError, - match=re.escape(f"Argument {nodes[1]} must be either a DynamicStructureNode or tuple containing 2 elements"), + match=re.escape( 
+ f"Argument {nodes[1]} must be either a DynamicStructureNode or tuple containing 2 elements" + ), ): sm.add_nodes(nodes) def test_single_node_bad_tuple(self): - node = (1,0,1) + node = (1, 0, 1) sm = DynamicStructureModel() with pytest.raises( TypeError, - match=re.escape(f"Argument {node} must be either a DynamicStructureNode or tuple containing 2 elements"), + match=re.escape( + f"Argument {node} must be either a DynamicStructureNode or tuple containing 2 elements" + ), ): - sm.add_node(node) \ No newline at end of file + sm.add_node(node) diff --git a/tests/structure/test_dynotears.py b/tests/structure/test_dynotears.py index 3aa3c58..6ee0636 100644 --- a/tests/structure/test_dynotears.py +++ b/tests/structure/test_dynotears.py @@ -26,7 +26,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from causalnex.structure.structuremodel import DynamicStructureNode import networkx as nx import numpy as np @@ -34,6 +33,7 @@ import pytest from causalnex.structure.dynotears import from_numpy_dynamic, from_pandas_dynamic +from causalnex.structure.structuremodel import DynamicStructureNode class TestFromNumpyDynotears: @@ -136,7 +136,9 @@ def test_expected_structure_learned_p1(self, data_dynotears_p1): edges_in_sm_and_a = [el for el in sm.edges if el in a_edges] sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0].get_node_name()] - assert sorted([el for el in sm.edges if "lag0" in el[0].get_node_name()]) == sorted(w_edges) + assert sorted( + [el for el in sm.edges if "lag0" in el[0].get_node_name()] + ) == sorted(w_edges) assert len(edges_in_sm_and_a) / len(a_edges) > 0.6 assert len(edges_in_sm_and_a) / len(sm_inter_edges) > 0.9 @@ -279,8 +281,8 @@ def test_certain_relationships_get_near_certain_weight(self): ) sm = from_numpy_dynamic(data.values[1:], data.values[:-1], w_threshold=0.1) edge = ( - sm.get_edge_data(DynamicStructureNode(1, 0), DynamicStructureNode(0, 0)) or - 
sm.get_edge_data(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)) + sm.get_edge_data(DynamicStructureNode(1, 0), DynamicStructureNode(0, 0)) + or sm.get_edge_data(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)) )["weight"] assert 0.99 < edge <= 1.01 @@ -294,8 +296,8 @@ def test_inverse_relationships_get_negative_weight(self): ) sm = from_numpy_dynamic(data.values[1:], data.values[:-1], w_threshold=0.1) edge = ( - sm.get_edge_data(DynamicStructureNode(1, 0), DynamicStructureNode(0, 0)) or - sm.get_edge_data(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)) + sm.get_edge_data(DynamicStructureNode(1, 0), DynamicStructureNode(0, 0)) + or sm.get_edge_data(DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)) )["weight"] assert -1.01 < edge <= -0.99 @@ -419,7 +421,9 @@ def test_expected_structure_learned_p1(self, data_dynotears_p1): edges_in_sm_and_a = [el for el in sm.edges if el in a_edges] sm_inter_edges = [el for el in sm.edges if "lag0" not in el[0].get_node_name()] - assert sorted(el for el in sm.edges if "lag0" in el[0].get_node_name()) == sorted(w_edges) + assert sorted( + el for el in sm.edges if "lag0" in el[0].get_node_name() + ) == sorted(w_edges) assert len(edges_in_sm_and_a) / len(a_edges) > 0.6 assert len(edges_in_sm_and_a) / len(sm_inter_edges) > 0.9 @@ -581,8 +585,10 @@ def test_certain_relationships_get_near_certain_weight(self): ) sm = from_pandas_dynamic(data, p=1, w_threshold=0.1) edge = ( - sm.get_edge_data(DynamicStructureNode('b', 0), DynamicStructureNode('a', 0)) or - sm.get_edge_data(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)) + sm.get_edge_data(DynamicStructureNode("b", 0), DynamicStructureNode("a", 0)) + or sm.get_edge_data( + DynamicStructureNode("a", 0), DynamicStructureNode("b", 0) + ) )["weight"] assert 0.99 < edge <= 1.01 @@ -596,8 +602,10 @@ def test_inverse_relationships_get_negative_weight(self): ) sm = from_pandas_dynamic(data, p=1, w_threshold=0.1) edge = ( - 
sm.get_edge_data(DynamicStructureNode('b', 0), DynamicStructureNode('a', 0)) or - sm.get_edge_data(DynamicStructureNode('a', 0), DynamicStructureNode('b', 0)) + sm.get_edge_data(DynamicStructureNode("b", 0), DynamicStructureNode("a", 0)) + or sm.get_edge_data( + DynamicStructureNode("a", 0), DynamicStructureNode("b", 0) + ) )["weight"] assert -1.01 < edge <= -0.99 From 85df54fe15f19699a601c73cd93708e686ebf308 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Mon, 17 Oct 2022 18:41:56 -0500 Subject: [PATCH 08/13] refactored coercion functions for linter --- causalnex/structure/structuremodel.py | 265 ++++++++++++++------------ 1 file changed, 143 insertions(+), 122 deletions(-) diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index 129d7e5..76680fd 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -348,122 +348,146 @@ class DynamicStructureNode(NamedTuple): time_step: int def get_node_name(self): + """ + Naming convention for dynamic nodes + """ return f"{self.node}_lag{self.time_step}" def check_collection_type(c): + """ + Check if data structure is a collection + """ return isinstance(c, (list, set, types.GeneratorType)) -def coerce_dsm_edges(arg): +def convert_to_dsm_edges(arg): """ - Used by DynamicStructureModel to convert edges as passed as primitive tuples to tuples of ``DynamicStructureNode``s. - An example input is [((0,0), (1,0), .5), ((1,0), (2,0), .5)]. 
This would be converted to - [(DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .5)] + If not all edges passed to DynamicStructureModel are tuples of DynamicStructureNode, convert them """ - multi_edge = check_collection_type(arg) - if multi_edge: - if isinstance(arg, types.GeneratorType): - arg = list(arg) - if not all(isinstance(e, Tuple) for e in arg): - raise TypeError( - f"Edges must be tuples containing 2 or 3 elements, received {arg}" - ) - if all( - isinstance(e[0], DynamicStructureNode) - and isinstance(e[1], DynamicStructureNode) - for e in arg - ): - return arg - else: - new_arg = [] - for e in arg: - if not isinstance(e[0], Tuple) or not isinstance(e[1], Tuple): - raise TypeError( - f"Nodes in {e} must be tuples with node name and time step" - ) - elif isinstance(e[0], DynamicStructureNode) and isinstance( - e[1], DynamicStructureNode - ): - new_arg.append(e) - elif len(e) == 2: - if not isinstance(e[0], DynamicStructureNode) and not isinstance( - e[1], DynamicStructureNode - ): - new_arg.append( - ( - DynamicStructureNode(e[0][0], e[0][1]), - DynamicStructureNode(e[1][0], e[1][1]), - ) - ) - elif not isinstance(e[0], DynamicStructureNode): - new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), e[1])) - elif not isinstance(e[1], DynamicStructureNode): - new_arg.append((e[0], DynamicStructureNode(e[1][0], e[1][1]))) - elif len(e) == 3: - if not isinstance(e[0], DynamicStructureNode) and not isinstance( - e[1], DynamicStructureNode - ): - new_arg.append( - ( - DynamicStructureNode(e[0][0], e[0][1]), - DynamicStructureNode(e[1][0], e[1][1]), - e[2], - ) - ) - elif not isinstance(e[0], DynamicStructureNode): - new_arg.append( - (DynamicStructureNode(e[0][0], e[0][1]), e[1], e[2]) - ) - elif not isinstance(e[1], DynamicStructureNode): - new_arg.append( - (e[0], DynamicStructureNode(e[1][0], e[1][1]), e[2]) - ) - else: - raise TypeError( - f"Argument {e} must be a tuple containing 2 or 3 
elements" - ) - return new_arg - else: - # if not isinstance(arg, Tuple): - # raise TypeError(f'Edges must be tuples containing 2 or 3 elements, received {arg}') - if not isinstance(arg[0], Tuple) or not isinstance(arg[1], Tuple): - raise TypeError( - f"Nodes in {arg} must be tuples with node name and time step" - ) - elif isinstance(arg[0], DynamicStructureNode) and isinstance( - arg[1], DynamicStructureNode + new_arg = [] + for e in arg: + if not isinstance(e[0], Tuple) or not isinstance(e[1], Tuple): + raise TypeError(f"Nodes in {e} must be tuples with node name and time step") + if isinstance(e[0], DynamicStructureNode) and isinstance( + e[1], DynamicStructureNode ): - return arg - elif len(arg) == 2: - if not isinstance(arg[0], DynamicStructureNode) and not isinstance( - arg[1], DynamicStructureNode + new_arg.append(e) + elif len(e) == 2: + if not isinstance(e[0], DynamicStructureNode) and not isinstance( + e[1], DynamicStructureNode ): - return ( - DynamicStructureNode(arg[0][0], arg[0][1]), - DynamicStructureNode(arg[1][0], arg[1][1]), + new_arg.append( + ( + DynamicStructureNode(e[0][0], e[0][1]), + DynamicStructureNode(e[1][0], e[1][1]), + ) ) - elif not isinstance(arg[0], DynamicStructureNode): - return (DynamicStructureNode(arg[0][0], arg[0][1]), arg[1]) - elif not isinstance(arg[1], DynamicStructureNode): - return (arg[0], DynamicStructureNode(arg[1][0], arg[1][1])) - elif len(arg) == 3: - if not isinstance(arg[0], DynamicStructureNode) and not isinstance( - arg[1], DynamicStructureNode + elif not isinstance(e[0], DynamicStructureNode): + new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), e[1])) + elif not isinstance(e[1], DynamicStructureNode): + new_arg.append((e[0], DynamicStructureNode(e[1][0], e[1][1]))) + elif len(e) == 3: + if not isinstance(e[0], DynamicStructureNode) and not isinstance( + e[1], DynamicStructureNode ): - return ( - DynamicStructureNode(arg[0][0], arg[0][1]), - DynamicStructureNode(arg[1][0], arg[1][1]), - arg[2], + 
new_arg.append( + ( + DynamicStructureNode(e[0][0], e[0][1]), + DynamicStructureNode(e[1][0], e[1][1]), + e[2], + ) ) - elif not isinstance(arg[0], DynamicStructureNode): - return (DynamicStructureNode(arg[0][0], arg[0][1]), arg[1], arg[2]) - elif not isinstance(arg[1], DynamicStructureNode): - return (arg[0], DynamicStructureNode(arg[1][0], arg[1][1]), arg[2]) + elif not isinstance(e[0], DynamicStructureNode): + new_arg.append((DynamicStructureNode(e[0][0], e[0][1]), e[1], e[2])) + elif not isinstance(e[1], DynamicStructureNode): + new_arg.append((e[0], DynamicStructureNode(e[1][0], e[1][1]), e[2])) else: - raise TypeError( - f"Argument {arg} must be either a DynamicStructureNode or tuple containing 2 or 3 elements" - ) + raise TypeError(f"Argument {e} must be a tuple containing 2 or 3 elements") + return new_arg + + +def convert_single_edge(arg): + """ + Used by coerce_dsm_edges to convert a single non weighted edge + """ + if not isinstance(arg[0], DynamicStructureNode) and not isinstance( + arg[1], DynamicStructureNode + ): + return ( + DynamicStructureNode(arg[0][0], arg[0][1]), + DynamicStructureNode(arg[1][0], arg[1][1]), + ) + if not isinstance(arg[0], DynamicStructureNode): + return (DynamicStructureNode(arg[0][0], arg[0][1]), arg[1]) + if not isinstance(arg[1], DynamicStructureNode): + return (arg[0], DynamicStructureNode(arg[1][0], arg[1][1])) + return arg + + +def convert_single_weighted_edge(arg): + """ + Used by coerce_dsm_edges to convert a single weighted edge + """ + if not isinstance(arg[0], DynamicStructureNode) and not isinstance( + arg[1], DynamicStructureNode + ): + return ( + DynamicStructureNode(arg[0][0], arg[0][1]), + DynamicStructureNode(arg[1][0], arg[1][1]), + arg[2], + ) + if not isinstance(arg[0], DynamicStructureNode): + return (DynamicStructureNode(arg[0][0], arg[0][1]), arg[1], arg[2]) + if not isinstance(arg[1], DynamicStructureNode): + return (arg[0], DynamicStructureNode(arg[1][0], arg[1][1]), arg[2]) + return arg + + +def 
coerce_dsm_multi_edge(arg): + """ + Coerce arguments containing multiple edges passed to DynamicStructureModel methods + """ + if isinstance(arg, types.GeneratorType): + arg = list(arg) + if not all(isinstance(e, Tuple) for e in arg): + raise TypeError( + f"Edges must be tuples containing 2 or 3 elements, received {arg}" + ) + if all( + isinstance(e[0], DynamicStructureNode) + and isinstance(e[1], DynamicStructureNode) + for e in arg + ): + return arg + return convert_to_dsm_edges(arg) + + +def coerce_dsm_edges(arg): + """ + Used by DynamicStructureModel to convert edges as passed as primitive tuples to tuples of ``DynamicStructureNode``s. + An example input is [((0,0), (1,0), .5), ((1,0), (2,0), .5)]. This would be converted to + [ + (DynamicStructureNode(0,0), DynamicStructureNode(1,0), .5), + (DynamicStructureNode(1,0), DynamicStructureNode(2,0), .5) + ] + """ + multi_edge = check_collection_type(arg) + if multi_edge: + return coerce_dsm_multi_edge(arg) + if not isinstance(arg[0], Tuple) or not isinstance(arg[1], Tuple): + raise TypeError(f"Nodes in {arg} must be tuples with node name and time step") + if isinstance(arg[0], DynamicStructureNode) and isinstance( + arg[1], DynamicStructureNode + ): + return arg + if len(arg) == 2: + return convert_single_edge(arg) + if len(arg) == 3: + return convert_single_weighted_edge(arg) + raise TypeError( + f"Argument {arg} must be either a DynamicStructureNode or tuple containing 2 or 3 elements" + ) def coerce_dsm_nodes(arg): @@ -476,27 +500,24 @@ def coerce_dsm_nodes(arg): arg = list(arg) if all(isinstance(n, DynamicStructureNode) for n in arg): return arg - else: - new_arg = [] - for n in arg: - if isinstance(n, DynamicStructureNode): - new_arg.append(n) - elif isinstance(n, Tuple) and len(n) == 2: - new_arg.append(DynamicStructureNode(n[0], n[1])) - else: - raise TypeError( - f"Argument {n} must be either a DynamicStructureNode or tuple containing 2 elements" - ) - return new_arg - else: - if isinstance(arg, 
DynamicStructureNode): - return arg - elif isinstance(arg, Tuple) and len(arg) == 2: - return DynamicStructureNode(arg[0], arg[1]) - else: - raise TypeError( - f"Argument {arg} must be either a DynamicStructureNode or tuple containing 2 elements" - ) + new_arg = [] + for n in arg: + if isinstance(n, DynamicStructureNode): + new_arg.append(n) + elif isinstance(n, Tuple) and len(n) == 2: + new_arg.append(DynamicStructureNode(n[0], n[1])) + else: + raise TypeError( + f"Argument {n} must be either a DynamicStructureNode or tuple containing 2 elements" + ) + return new_arg + if isinstance(arg, DynamicStructureNode): + return arg + if isinstance(arg, Tuple) and len(arg) == 2: + return DynamicStructureNode(arg[0], arg[1]) + raise TypeError( + f"Argument {arg} must be either a DynamicStructureNode or tuple containing 2 elements" + ) class DynamicStructureModel(StructureModel): From e34376ca59c7568045b9bb941a801c85b504f8e6 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Mon, 17 Oct 2022 20:45:40 -0500 Subject: [PATCH 09/13] refactor dsn tests for linter --- causalnex/structure/structuremodel.py | 36 +- tests/structure/test_dsn_coercion.py | 328 ++++++++ tests/structure/test_dsn_subgraph.py | 465 +++++++++++ tests/structure/test_dynamicstructuremodel.py | 741 +----------------- 4 files changed, 805 insertions(+), 765 deletions(-) create mode 100644 tests/structure/test_dsn_coercion.py create mode 100644 tests/structure/test_dsn_subgraph.py diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index 76680fd..e473c77 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -538,31 +538,14 @@ class DynamicStructureModel(StructureModel): Edges are represented as links between nodes with optional key/value attributes. """ - def __init__(self, incoming_graph_data=None, origin="unknown", **attr): - """ - Create a ``DynamicStructureModel`` with incoming_graph_data, which has come from some origin. 
- - Args: - incoming_graph_data (Optional): input graph (optional, default: None) - Data to initialize graph. If None (default) an empty graph is created. - The data can be any format that is supported by the to_networkx_graph() - function, currently including edge list, dict of dicts, dict of lists, - NetworkX graph, NumPy matrix or 2d ndarray, SciPy sparse matrix, or PyGraphviz graph. - - origin (str): label for how the edges were created. Can be one of: - - unknown: edges exist for an unknown reason; - - learned: edges were created as the output of a machine-learning process; - - expert: edges were created by a domain expert. - - attr : Attributes to add to graph as key/value pairs (no attributes by default). - """ - super().__init__(incoming_graph_data, origin, **attr) - - def add_node(self, dnode: DynamicStructureNode): - dnode = coerce_dsm_nodes(dnode) - super().add_node(dnode) + def add_node(self, node_for_adding: DynamicStructureNode, **attr): + dnode = coerce_dsm_nodes(node_for_adding) + super().add_node(dnode, **attr) def add_nodes(self, dnodes: List[DynamicStructureNode]): + """ + Add multiple `DynamicStructureNode` to graph + """ dnodes = coerce_dsm_nodes(dnodes) super().add_nodes_from(dnodes) @@ -596,6 +579,7 @@ def get_markov_blanket( nodes: Union[ DynamicStructureNode, List[DynamicStructureNode], Set[DynamicStructureNode] ], + cls: nx.DiGraph = None, ) -> "DynamicStructureModel": """ Get Markov blanket of specified target nodes @@ -614,12 +598,12 @@ def get_markov_blanket( def add_edge( self, - u: DynamicStructureNode, - v: DynamicStructureNode, + u_of_edge: DynamicStructureNode, + v_of_edge: DynamicStructureNode, origin: str = "unknown", **attr, ): - edge = coerce_dsm_edges((u, v)) + edge = coerce_dsm_edges((u_of_edge, v_of_edge)) super().add_edge(edge[0], edge[1], origin, **attr) # disabled: W0221: Parameters differ from overridden 'add_edge' method (arguments-differ) diff --git a/tests/structure/test_dsn_coercion.py 
b/tests/structure/test_dsn_coercion.py new file mode 100644 index 0000000..81862ba --- /dev/null +++ b/tests/structure/test_dsn_coercion.py @@ -0,0 +1,328 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re + +import pytest + +from causalnex.structure import DynamicStructureModel, DynamicStructureNode + + +class TestDynamicStructureModelEdgeCoercion: + def test_edge_not_tuple(self): + edges = [((1, 0), (3, 0), 0.5), 6] + sm = DynamicStructureModel() + + with pytest.raises( + TypeError, + match=re.escape( + f"Edges must be tuples containing 2 or 3 elements, received {edges}" + ), + ): + sm.add_edges_from(edges) + + def test_multi_edge_not_dsn(self): + edges = [((0, 0), (1, 0)), ((1, 0), (2, 0))] + sm = DynamicStructureModel() + sm.add_edges_from(edges) + + expected_edges = [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + ] + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected_edges) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == set(expected_graph.edges) + + def test_weighted_multi_edge_not_dsn(self): + edges = [((0, 0), (1, 0), 0.5), ((1, 0), (2, 0), 0.7)] + sm = DynamicStructureModel() + sm.add_weighted_edges_from(edges) + + expected_edges = [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 0.7), + ] + expected_graph = DynamicStructureModel() + expected_graph.add_weighted_edges_from(expected_edges) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == set(expected_graph.edges) + + @pytest.mark.parametrize( + "input_edges, expected_edges", + [ + ( + [((0, 0), (1, 0), 0.5), ((1, 0), (2, 0), 0.7)], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 0.7), + ], + ), + ( + [((0, 0), (1, 0)), ((1, 0), (2, 0))], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + ], + ), + ], + ) + def test_multi_edge_dsn(self, input_edges, expected_edges): + sm = DynamicStructureModel() + 
weighted = len(input_edges[0]) == 3 + if not weighted: + sm.add_edges_from(input_edges) + else: + sm.add_weighted_edges_from(input_edges) + + expected_graph = DynamicStructureModel() + if not weighted: + expected_graph.add_edges_from(expected_edges) + else: + expected_graph.add_weighted_edges_from(expected_edges) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == set(expected_graph.edges) + + def test_node_not_tuple(self): + edges = [((1, 0), (3, 0), 0.5), ((1, 0), 3, 0.7)] + sm = DynamicStructureModel() + + with pytest.raises( + TypeError, + match=re.escape( + f"Nodes in {edges[1]} must be tuples with node name and time step" + ), + ): + sm.add_edges_from(edges) + + @pytest.mark.parametrize( + "input_edges, expected_edges", + [ + ( + [ + (DynamicStructureNode(0, 0), (1, 0), 0.5), + ((1, 0), DynamicStructureNode(2, 0), 0.7), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 0.7), + ], + ), + ( + [ + (DynamicStructureNode(0, 0), (1, 0)), + ((1, 0), DynamicStructureNode(2, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + ], + ), + ( + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + ((1, 0), DynamicStructureNode(2, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + ], + ), + ], + ) + def test_multi_edge_one_dsn(self, input_edges, expected_edges): + sm = DynamicStructureModel() + weighted = len(input_edges[0]) == 3 + if not weighted: + sm.add_edges_from(input_edges) + else: + sm.add_weighted_edges_from(input_edges) + + expected_graph = DynamicStructureModel() + if not weighted: + expected_graph.add_edges_from(expected_edges) + else: + expected_graph.add_weighted_edges_from(expected_edges) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == 
set(expected_graph.edges) + + def test_multi_edge_bad_tuple(self): + edges = [((0, 0), (1, 0), 0.5), ((1, 0), (2, 0), 0.7, 0.8)] + sm = DynamicStructureModel() + with pytest.raises( + TypeError, + match=re.escape( + f"Argument {edges[1]} must be a tuple containing 2 or 3 elements" + ), + ): + sm.add_weighted_edges_from(edges) + + def test_single_edge_node_not_tuple(self): + u = (1, 0) + v = 3 + sm = DynamicStructureModel() + + with pytest.raises( + TypeError, + match=re.escape( + f"Nodes in {(u, v)} must be tuples with node name and time step" + ), + ): + sm.add_edge(u, v) + + @pytest.mark.parametrize( + "input_edge, expected_edge", + [ + ( + ((0, 0), (1, 0), 0.5), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + ), + ( + ((0, 0), (1, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + ), + ( + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + ), + ( + (DynamicStructureNode(0, 0), (1, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + ), + ( + ((0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + ), + ( + (DynamicStructureNode(0, 0), (1, 0), 0.5), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + ), + ( + ((0, 0), DynamicStructureNode(1, 0), 0.5), + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), + ), + ], + ) + def test_single_edge_dsn(self, input_edge, expected_edge): + sm = DynamicStructureModel() + weighted = len(input_edge) == 3 + if not weighted: + sm.add_edge(input_edge[0], input_edge[1]) + else: + sm.add_weighted_edges_from(input_edge) + + expected_graph = DynamicStructureModel() + if not weighted: + expected_graph.add_edge(expected_edge[0], expected_edge[1]) + else: + expected_graph.add_weighted_edges_from(expected_edge) + + assert set(sm.nodes) == set(expected_graph.nodes) + assert set(sm.edges) == set(expected_graph.edges) + + def 
test_single_edge_bad_tuple(self): + edge = ((1, 0), (2, 0), 0.7, 0.8) + sm = DynamicStructureModel() + with pytest.raises( + TypeError, + match=re.escape( + f"Argument {edge} must be either a DynamicStructureNode or tuple containing 2 or 3 elements" + ), + ): + sm.add_weighted_edges_from(edge) + + +class TestDynamicStructureModelNodeCoercion: + @pytest.mark.parametrize( + "input_nodes, expected_nodes", + [ + ( + [(0, 0), (1, 0)], + [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)], + ), + ( + [DynamicStructureNode(0, 0), (1, 0)], + [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)], + ), + ( + (DynamicStructureNode(n, 0) for n in range(2)), + [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)], + ), + ], + ) + def test_multi_node(self, input_nodes, expected_nodes): + sm = DynamicStructureModel() + sm.add_nodes(input_nodes) + + expected_graph = DynamicStructureModel() + expected_graph.add_nodes(expected_nodes) + + assert set(sm.nodes) == set(expected_graph.nodes) + + @pytest.mark.parametrize( + "input_node, expected_node", + [ + ((0, 0), DynamicStructureNode(0, 0)), + (DynamicStructureNode(0, 0), DynamicStructureNode(0, 0)), + ], + ) + def test_single_node(self, input_node, expected_node): + sm = DynamicStructureModel() + sm.add_nodes(input_node) + + expected_graph = DynamicStructureModel() + expected_graph.add_nodes(expected_node) + + assert set(sm.nodes) == set(expected_graph.nodes) + + def test_multi_node_bad_tuple(self): + nodes = [(0, 0), (1, 0, 1)] + sm = DynamicStructureModel() + with pytest.raises( + TypeError, + match=re.escape( + f"Argument {nodes[1]} must be either a DynamicStructureNode or tuple containing 2 elements" + ), + ): + sm.add_nodes(nodes) + + def test_single_node_bad_tuple(self): + node = (1, 0, 1) + sm = DynamicStructureModel() + with pytest.raises( + TypeError, + match=re.escape( + f"Argument {node} must be either a DynamicStructureNode or tuple containing 2 elements" + ), + ): + sm.add_node(node) diff --git 
a/tests/structure/test_dsn_subgraph.py b/tests/structure/test_dsn_subgraph.py new file mode 100644 index 0000000..aee3e08 --- /dev/null +++ b/tests/structure/test_dsn_subgraph.py @@ -0,0 +1,465 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re + +import pytest +from networkx.exception import NodeNotFound + +from causalnex.structure import DynamicStructureModel, DynamicStructureNode + +class TestDynamicStructureModelGetLargestSubgraph: + @pytest.mark.parametrize( + "test_input, expected", + [ + ( + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + ], + ), + ( + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + ], + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + ], + ), + # ([(0, 1), (1, 2), (1, 3), (4, 6)], [(0, 1), (1, 2), (1, 3)]), + # ([(3, 4), (3, 5), (7, 6)], [(3, 4), (3, 5)]), + ], + ) + def test_get_largest_subgraph(self, test_input, expected): + """Should be able to return the largest subgraph""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + largest_subgraph = sm.get_largest_subgraph() + + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected) + + assert set(largest_subgraph.nodes) == set(expected_graph.nodes) + assert set(largest_subgraph.edges) == set(expected_graph.edges) + + def test_more_than_one_largest(self): + """Return the first largest when there are more than one largest subgraph""" + + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(5, 0), + ] + edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + 
(nodes[3], nodes[4]), + (nodes[3], nodes[5]), + ] + sm = DynamicStructureModel() + sm.add_edges_from(edges) + largest_subgraph = sm.get_largest_subgraph() + + expected_edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected_edges) + + assert set(largest_subgraph.nodes) == set(expected_graph.nodes) + assert set(largest_subgraph.edges) == set(expected_graph.edges) + + def test_empty(self): + """Should return None if the structure model is empty""" + + sm = DynamicStructureModel() + assert sm.get_largest_subgraph() is None + + def test_isolates(self): + """Should return None if the structure model only contains isolates""" + + sm = DynamicStructureModel() + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(7, 0), + ] + sm.add_nodes(nodes) + assert sm.get_largest_subgraph() is None + + def test_isolates_nodes_and_edges(self): + """Should be able to return the largest subgraph""" + + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(6, 0), + DynamicStructureNode(7, 0), + DynamicStructureNode(8, 0), + DynamicStructureNode(9, 0), + ] + edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[5], nodes[6]), + ] + isolated_nodes = [nodes[7], nodes[8], nodes[9]] + sm = DynamicStructureModel() + sm.add_edges_from(edges) + sm.add_nodes(isolated_nodes) + largest_subgraph = sm.get_largest_subgraph() + + expected_edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + ] + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected_edges) + + assert set(largest_subgraph.nodes) == set(expected_graph.nodes) + assert set(largest_subgraph.edges) == 
set(expected_graph.edges) + + def test_different_origins_and_weights(self): + """The largest subgraph returned should still have the edge data preserved from the original graph""" + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(6, 0), + ] + sm = DynamicStructureModel() + sm.add_weighted_edges_from([(nodes[0], nodes[1], 2.0)], origin="unknown") + sm.add_weighted_edges_from([(nodes[0], nodes[2], 1.0)], origin="learned") + sm.add_weighted_edges_from([(nodes[3], nodes[4], 0.7)], origin="expert") + + largest_subgraph = sm.get_largest_subgraph() + + assert set(largest_subgraph.edges.data("origin")) == { + (nodes[0], nodes[1], "unknown"), + (nodes[0], nodes[2], "learned"), + } + assert set(largest_subgraph.edges.data("weight")) == { + (nodes[0], nodes[1], 2.0), + (nodes[0], nodes[2], 1.0), + } + + +class TestDynamicStructureModelGetTargetSubgraph: + @pytest.mark.parametrize( + "target_node, test_input, expected", + [ + ( + DynamicStructureNode(1, 0), + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), + ], + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + ], + ), + ( + DynamicStructureNode(3, 0), + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + ], + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + ], + ), + ( + DynamicStructureNode(7, 0), + [ + (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), + (DynamicStructureNode(1, 0), 
DynamicStructureNode(2, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(5, 0), DynamicStructureNode(1, 0)), + ], + [ + (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + ], + ), + ], + ) + def test_get_target_subgraph(self, target_node, test_input, expected): + """Should be able to return the subgraph with the specified node""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + subgraph = sm.get_target_subgraph(target_node) + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + @pytest.mark.parametrize( + "target_node, test_input, expected", + [ + ( + DynamicStructureNode("a", 0), + [ + (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), + (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), + (DynamicStructureNode("e", 0), DynamicStructureNode("f", 0)), + ], + [ + (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), + (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), + ], + ), + ( + DynamicStructureNode("g", 0), + [ + (DynamicStructureNode("g", 0), DynamicStructureNode("h", 0)), + (DynamicStructureNode("g", 0), DynamicStructureNode("z", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), + (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), + (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), + ], + [ + (DynamicStructureNode("g", 0), DynamicStructureNode("h", 0)), + (DynamicStructureNode("g", 0), DynamicStructureNode("z", 0)), + ], + ), + ], + ) + def test_get_subgraph_string(self, target_node, test_input, expected): + """Should 
be able to return the subgraph with the specified node""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + subgraph = sm.get_target_subgraph(target_node) + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + @pytest.mark.parametrize( + "target_node, test_input", + [ + ( + DynamicStructureNode(7, 0), + [ + (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), + (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), + (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), + ], + ), + ( + DynamicStructureNode(1, 0), + [ + (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), + (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), + (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), + ], + ), + ], + ) + def test_node_not_in_graph(self, target_node, test_input): + """Should raise an error if the target_node is not found in the graph""" + + sm = DynamicStructureModel() + sm.add_edges_from(test_input) + + with pytest.raises( + NodeNotFound, + match=re.escape(f"Node {target_node} not found in the graph"), + ): + sm.get_target_subgraph(target_node) + + def test_isolates(self): + """Should return an isolated node""" + + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(7, 0), + ] + sm = DynamicStructureModel() + sm.add_nodes(nodes) + subgraph = sm.get_target_subgraph(DynamicStructureNode(1, 0)) + expected_graph = DynamicStructureModel() + expected_graph.add_node(DynamicStructureNode(1, 0)) + print(f"subgraph nodes {subgraph.nodes}\n") + print(f"expected nodes {expected_graph.nodes}") + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + def 
test_isolates_nodes_and_edges(self): + """Should be able to return the subgraph with the specified node""" + + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(6, 0), + DynamicStructureNode(7, 0), + DynamicStructureNode(8, 0), + DynamicStructureNode(9, 0), + ] + edges = [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[5], nodes[6]), + (nodes[4], nodes[5]), + ] + isolated_nodes = [nodes[7], nodes[8], nodes[9]] + sm = DynamicStructureModel() + sm.add_edges_from(edges) + sm.add_nodes(isolated_nodes) + subgraph = sm.get_target_subgraph(nodes[5]) + expected_edges = [(nodes[5], nodes[6]), (nodes[4], nodes[5])] + expected_graph = DynamicStructureModel() + expected_graph.add_edges_from(expected_edges) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) + + def test_different_origins_and_weights(self): + """The subgraph returned should still have the edge data preserved from the original graph""" + + sm = DynamicStructureModel() + nodes = [ + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(5, 0), + DynamicStructureNode(6, 0), + ] + + sm.add_weighted_edges_from([(nodes[0], nodes[1], 2.0)], origin="unknown") + sm.add_weighted_edges_from([(nodes[0], nodes[2], 1.0)], origin="learned") + sm.add_weighted_edges_from([(nodes[3], nodes[4], 0.7)], origin="expert") + + subgraph = sm.get_target_subgraph(nodes[1]) + + assert set(subgraph.edges.data("origin")) == { + (nodes[0], nodes[1], "unknown"), + (nodes[0], nodes[2], "learned"), + } + assert set(subgraph.edges.data("weight")) == { + (nodes[0], nodes[1], 2.0), + (nodes[0], nodes[2], 1.0), + } + + def test_instance_type(self): + """The subgraph returned should still be a DynamicStructureModel instance""" + + sm = 
DynamicStructureModel() + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(6, 0), + ] + sm.add_edges_from( + [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[4], nodes[5]), + ] + ) + subgraph = sm.get_target_subgraph(nodes[2]) + + assert isinstance(subgraph, DynamicStructureModel) + + def test_get_target_subgraph_twice(self): + """get_target_subgraph should be able to run more than once""" + + sm = DynamicStructureModel() + nodes = [ + DynamicStructureNode(0, 0), + DynamicStructureNode(1, 0), + DynamicStructureNode(2, 0), + DynamicStructureNode(3, 0), + DynamicStructureNode(4, 0), + DynamicStructureNode(6, 0), + ] + sm.add_edges_from( + [ + (nodes[0], nodes[1]), + (nodes[1], nodes[2]), + (nodes[1], nodes[3]), + (nodes[4], nodes[5]), + ] + ) + + subgraph = sm.get_target_subgraph(nodes[0]) + subgraph.remove_edge(nodes[0], nodes[1]) + subgraph = subgraph.get_target_subgraph(nodes[1]) + + expected_graph = DynamicStructureModel() + expected_edges = [(nodes[1], nodes[2]), (nodes[1], nodes[3])] + expected_graph.add_edges_from(expected_edges) + + assert set(subgraph.nodes) == set(expected_graph.nodes) + assert set(subgraph.edges) == set(expected_graph.edges) diff --git a/tests/structure/test_dynamicstructuremodel.py b/tests/structure/test_dynamicstructuremodel.py index ed4cf2b..a222743 100644 --- a/tests/structure/test_dynamicstructuremodel.py +++ b/tests/structure/test_dynamicstructuremodel.py @@ -609,442 +609,10 @@ def test_graph_with_no_edges(self): sm.add_nodes(nodes) sm.remove_edges_below_threshold(0.6) - assert set(sm.nodes) == set([node for node in nodes]) + assert set(sm.nodes) == set(nodes) assert set(sm.edges) == set() -class TestDynamicStructureModelGetLargestSubgraph: - @pytest.mark.parametrize( - "test_input, expected", - [ - ( - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - 
(DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), - (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), - ], - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), - ], - ), - ( - [ - (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), - (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), - (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), - ], - [ - (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), - (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), - ], - ), - # ([(0, 1), (1, 2), (1, 3), (4, 6)], [(0, 1), (1, 2), (1, 3)]), - # ([(3, 4), (3, 5), (7, 6)], [(3, 4), (3, 5)]), - ], - ) - def test_get_largest_subgraph(self, test_input, expected): - """Should be able to return the largest subgraph""" - - sm = DynamicStructureModel() - sm.add_edges_from(test_input) - largest_subgraph = sm.get_largest_subgraph() - - expected_graph = DynamicStructureModel() - expected_graph.add_edges_from(expected) - - assert set(largest_subgraph.nodes) == set(expected_graph.nodes) - assert set(largest_subgraph.edges) == set(expected_graph.edges) - - def test_more_than_one_largest(self): - """Return the first largest when there are more than one largest subgraph""" - - nodes = [ - DynamicStructureNode(0, 0), - DynamicStructureNode(1, 0), - DynamicStructureNode(2, 0), - DynamicStructureNode(3, 0), - DynamicStructureNode(4, 0), - DynamicStructureNode(5, 0), - ] - edges = [ - (nodes[0], nodes[1]), - (nodes[1], nodes[2]), - (nodes[3], nodes[4]), - (nodes[3], nodes[5]), - ] - sm = DynamicStructureModel() - sm.add_edges_from(edges) - largest_subgraph = sm.get_largest_subgraph() - - expected_edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])] - expected_graph = DynamicStructureModel() - expected_graph.add_edges_from(expected_edges) - - assert 
set(largest_subgraph.nodes) == set(expected_graph.nodes) - assert set(largest_subgraph.edges) == set(expected_graph.edges) - - def test_empty(self): - """Should return None if the structure model is empty""" - - sm = DynamicStructureModel() - assert sm.get_largest_subgraph() is None - - def test_isolates(self): - """Should return None if the structure model only contains isolates""" - - sm = DynamicStructureModel() - nodes = [ - DynamicStructureNode(1, 0), - DynamicStructureNode(3, 0), - DynamicStructureNode(5, 0), - DynamicStructureNode(2, 0), - DynamicStructureNode(7, 0), - ] - sm.add_nodes(nodes) - assert sm.get_largest_subgraph() is None - - def test_isolates_nodes_and_edges(self): - """Should be able to return the largest subgraph""" - - nodes = [ - DynamicStructureNode(0, 0), - DynamicStructureNode(1, 0), - DynamicStructureNode(2, 0), - DynamicStructureNode(3, 0), - DynamicStructureNode(4, 0), - DynamicStructureNode(5, 0), - DynamicStructureNode(6, 0), - DynamicStructureNode(7, 0), - DynamicStructureNode(8, 0), - DynamicStructureNode(9, 0), - ] - edges = [ - (nodes[0], nodes[1]), - (nodes[1], nodes[2]), - (nodes[1], nodes[3]), - (nodes[5], nodes[6]), - ] - isolated_nodes = [nodes[7], nodes[8], nodes[9]] - sm = DynamicStructureModel() - sm.add_edges_from(edges) - sm.add_nodes(isolated_nodes) - largest_subgraph = sm.get_largest_subgraph() - - expected_edges = [ - (nodes[0], nodes[1]), - (nodes[1], nodes[2]), - (nodes[1], nodes[3]), - ] - expected_graph = DynamicStructureModel() - expected_graph.add_edges_from(expected_edges) - - assert set(largest_subgraph.nodes) == set(expected_graph.nodes) - assert set(largest_subgraph.edges) == set(expected_graph.edges) - - def test_different_origins_and_weights(self): - """The largest subgraph returned should still have the edge data preserved from the original graph""" - nodes = [ - DynamicStructureNode(1, 0), - DynamicStructureNode(2, 0), - DynamicStructureNode(3, 0), - DynamicStructureNode(5, 0), - 
DynamicStructureNode(6, 0), - ] - sm = DynamicStructureModel() - sm.add_weighted_edges_from([(nodes[0], nodes[1], 2.0)], origin="unknown") - sm.add_weighted_edges_from([(nodes[0], nodes[2], 1.0)], origin="learned") - sm.add_weighted_edges_from([(nodes[3], nodes[4], 0.7)], origin="expert") - - largest_subgraph = sm.get_largest_subgraph() - - assert set(largest_subgraph.edges.data("origin")) == { - (nodes[0], nodes[1], "unknown"), - (nodes[0], nodes[2], "learned"), - } - assert set(largest_subgraph.edges.data("weight")) == { - (nodes[0], nodes[1], 2.0), - (nodes[0], nodes[2], 1.0), - } - - -class TestDynamicStructureModelGetTargetSubgraph: - @pytest.mark.parametrize( - "target_node, test_input, expected", - [ - ( - DynamicStructureNode(1, 0), - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), - (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), - ], - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), - ], - ), - ( - DynamicStructureNode(3, 0), - [ - (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), - (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), - (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), - ], - [ - (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), - (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), - ], - ), - ( - DynamicStructureNode(7, 0), - [ - (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), - (DynamicStructureNode(2, 0), DynamicStructureNode(3, 0)), - (DynamicStructureNode(5, 0), DynamicStructureNode(1, 0)), - ], - [ - (DynamicStructureNode(7, 0), DynamicStructureNode(8, 0)), - (DynamicStructureNode(7, 0), 
DynamicStructureNode(6, 0)), - ], - ), - ], - ) - def test_get_target_subgraph(self, target_node, test_input, expected): - """Should be able to return the subgraph with the specified node""" - - sm = DynamicStructureModel() - sm.add_edges_from(test_input) - subgraph = sm.get_target_subgraph(target_node) - expected_graph = DynamicStructureModel() - expected_graph.add_edges_from(expected) - - assert set(subgraph.nodes) == set(expected_graph.nodes) - assert set(subgraph.edges) == set(expected_graph.edges) - - @pytest.mark.parametrize( - "target_node, test_input, expected", - [ - ( - DynamicStructureNode("a", 0), - [ - (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), - (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), - (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), - (DynamicStructureNode("e", 0), DynamicStructureNode("f", 0)), - ], - [ - (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), - (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), - (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), - ], - ), - ( - DynamicStructureNode("g", 0), - [ - (DynamicStructureNode("g", 0), DynamicStructureNode("h", 0)), - (DynamicStructureNode("g", 0), DynamicStructureNode("z", 0)), - (DynamicStructureNode("a", 0), DynamicStructureNode("b", 0)), - (DynamicStructureNode("a", 0), DynamicStructureNode("c", 0)), - (DynamicStructureNode("c", 0), DynamicStructureNode("d", 0)), - ], - [ - (DynamicStructureNode("g", 0), DynamicStructureNode("h", 0)), - (DynamicStructureNode("g", 0), DynamicStructureNode("z", 0)), - ], - ), - ], - ) - def test_get_subgraph_string(self, target_node, test_input, expected): - """Should be able to return the subgraph with the specified node""" - - sm = DynamicStructureModel() - sm.add_edges_from(test_input) - subgraph = sm.get_target_subgraph(target_node) - expected_graph = DynamicStructureModel() - expected_graph.add_edges_from(expected) - - assert set(subgraph.nodes) == 
set(expected_graph.nodes) - assert set(subgraph.edges) == set(expected_graph.edges) - - @pytest.mark.parametrize( - "target_node, test_input", - [ - ( - DynamicStructureNode(7, 0), - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(3, 0)), - (DynamicStructureNode(4, 0), DynamicStructureNode(6, 0)), - ], - ), - ( - DynamicStructureNode(1, 0), - [ - (DynamicStructureNode(3, 0), DynamicStructureNode(4, 0)), - (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), - (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), - ], - ), - ], - ) - def test_node_not_in_graph(self, target_node, test_input): - """Should raise an error if the target_node is not found in the graph""" - - sm = DynamicStructureModel() - sm.add_edges_from(test_input) - - with pytest.raises( - NodeNotFound, - match=re.escape(f"Node {target_node} not found in the graph"), - ): - sm.get_target_subgraph(target_node) - - def test_isolates(self): - """Should return an isolated node""" - - nodes = [ - DynamicStructureNode(1, 0), - DynamicStructureNode(3, 0), - DynamicStructureNode(5, 0), - DynamicStructureNode(2, 0), - DynamicStructureNode(7, 0), - ] - sm = DynamicStructureModel() - sm.add_nodes(nodes) - subgraph = sm.get_target_subgraph(DynamicStructureNode(1, 0)) - expected_graph = DynamicStructureModel() - expected_graph.add_node(DynamicStructureNode(1, 0)) - print(f"subgraph nodes {subgraph.nodes}\n") - print(f"expected nodes {expected_graph.nodes}") - assert set(subgraph.nodes) == set(expected_graph.nodes) - assert set(subgraph.edges) == set(expected_graph.edges) - - def test_isolates_nodes_and_edges(self): - """Should be able to return the subgraph with the specified node""" - - nodes = [ - DynamicStructureNode(0, 0), - DynamicStructureNode(1, 0), - DynamicStructureNode(2, 0), - DynamicStructureNode(3, 0), - DynamicStructureNode(4, 0), - DynamicStructureNode(5, 0), - 
DynamicStructureNode(6, 0), - DynamicStructureNode(7, 0), - DynamicStructureNode(8, 0), - DynamicStructureNode(9, 0), - ] - edges = [ - (nodes[0], nodes[1]), - (nodes[1], nodes[2]), - (nodes[1], nodes[3]), - (nodes[5], nodes[6]), - (nodes[4], nodes[5]), - ] - isolated_nodes = [nodes[7], nodes[8], nodes[9]] - sm = DynamicStructureModel() - sm.add_edges_from(edges) - sm.add_nodes(isolated_nodes) - subgraph = sm.get_target_subgraph(nodes[5]) - expected_edges = [(nodes[5], nodes[6]), (nodes[4], nodes[5])] - expected_graph = DynamicStructureModel() - expected_graph.add_edges_from(expected_edges) - - assert set(subgraph.nodes) == set(expected_graph.nodes) - assert set(subgraph.edges) == set(expected_graph.edges) - - def test_different_origins_and_weights(self): - """The subgraph returned should still have the edge data preserved from the original graph""" - - sm = DynamicStructureModel() - nodes = [ - DynamicStructureNode(1, 0), - DynamicStructureNode(2, 0), - DynamicStructureNode(3, 0), - DynamicStructureNode(5, 0), - DynamicStructureNode(6, 0), - ] - - sm.add_weighted_edges_from([(nodes[0], nodes[1], 2.0)], origin="unknown") - sm.add_weighted_edges_from([(nodes[0], nodes[2], 1.0)], origin="learned") - sm.add_weighted_edges_from([(nodes[3], nodes[4], 0.7)], origin="expert") - - subgraph = sm.get_target_subgraph(nodes[1]) - - assert set(subgraph.edges.data("origin")) == { - (nodes[0], nodes[1], "unknown"), - (nodes[0], nodes[2], "learned"), - } - assert set(subgraph.edges.data("weight")) == { - (nodes[0], nodes[1], 2.0), - (nodes[0], nodes[2], 1.0), - } - - def test_instance_type(self): - """The subgraph returned should still be a DynamicStructureModel instance""" - - sm = DynamicStructureModel() - nodes = [ - DynamicStructureNode(0, 0), - DynamicStructureNode(1, 0), - DynamicStructureNode(2, 0), - DynamicStructureNode(3, 0), - DynamicStructureNode(4, 0), - DynamicStructureNode(6, 0), - ] - sm.add_edges_from( - [ - (nodes[0], nodes[1]), - (nodes[1], nodes[2]), - 
(nodes[1], nodes[3]), - (nodes[4], nodes[5]), - ] - ) - subgraph = sm.get_target_subgraph(nodes[2]) - - assert isinstance(subgraph, DynamicStructureModel) - - def test_get_target_subgraph_twice(self): - """get_target_subgraph should be able to run more than once""" - - sm = DynamicStructureModel() - nodes = [ - DynamicStructureNode(0, 0), - DynamicStructureNode(1, 0), - DynamicStructureNode(2, 0), - DynamicStructureNode(3, 0), - DynamicStructureNode(4, 0), - DynamicStructureNode(6, 0), - ] - sm.add_edges_from( - [ - (nodes[0], nodes[1]), - (nodes[1], nodes[2]), - (nodes[1], nodes[3]), - (nodes[4], nodes[5]), - ] - ) - - subgraph = sm.get_target_subgraph(nodes[0]) - subgraph.remove_edge(nodes[0], nodes[1]) - subgraph = subgraph.get_target_subgraph(nodes[1]) - - expected_graph = DynamicStructureModel() - expected_edges = [(nodes[1], nodes[2]), (nodes[1], nodes[3])] - expected_graph.add_edges_from(expected_edges) - - assert set(subgraph.nodes) == set(expected_graph.nodes) - assert set(subgraph.edges) == set(expected_graph.edges) - - class TestDynamicStructureModelGetMarkovBlanket: @pytest.mark.parametrize( "target_node, test_input, expected", @@ -1257,7 +825,7 @@ def test_get_markov_blanket_string(self, target_node, test_input, expected): (DynamicStructureNode(3, 0), DynamicStructureNode(5, 0)), (DynamicStructureNode(7, 0), DynamicStructureNode(6, 0)), ], - ) + ), ], ) def test_node_not_in_graph(self, target_node, test_input): @@ -1348,308 +916,3 @@ def test_instance_type(self): subgraph = sm.get_markov_blanket(nodes[2]) assert isinstance(subgraph, DynamicStructureModel) - - -class TestDynamicStructureModelEdgeCoercion: - def test_edge_not_tuple(self): - edges = [((1, 0), (3, 0), 0.5), 6] - sm = DynamicStructureModel() - - with pytest.raises( - TypeError, - match=re.escape( - f"Edges must be tuples containing 2 or 3 elements, received {edges}" - ), - ): - sm.add_edges_from(edges) - - def test_multi_edge_not_dsn(self): - edges = [((0, 0), (1, 0)), ((1, 0), (2, 0))] - 
sm = DynamicStructureModel() - sm.add_edges_from(edges) - - expected_edges = [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - ] - expected_graph = DynamicStructureModel() - expected_graph.add_edges_from(expected_edges) - - assert set(sm.nodes) == set(expected_graph.nodes) - assert set(sm.edges) == set(expected_graph.edges) - - def test_weighted_multi_edge_not_dsn(self): - edges = [((0, 0), (1, 0), 0.5), ((1, 0), (2, 0), 0.7)] - sm = DynamicStructureModel() - sm.add_weighted_edges_from(edges) - - expected_edges = [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 0.7), - ] - expected_graph = DynamicStructureModel() - expected_graph.add_weighted_edges_from(expected_edges) - - assert set(sm.nodes) == set(expected_graph.nodes) - assert set(sm.edges) == set(expected_graph.edges) - - @pytest.mark.parametrize( - "input_edges, expected_edges", - [ - ( - [((0, 0), (1, 0), 0.5), ((1, 0), (2, 0), 0.7)], - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 0.7), - ], - ), - ( - [((0, 0), (1, 0)), ((1, 0), (2, 0))], - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - ], - ), - ], - ) - def test_multi_edge_dsn(self, input_edges, expected_edges): - sm = DynamicStructureModel() - weighted = len(input_edges[0]) == 3 - if not weighted: - sm.add_edges_from(input_edges) - else: - sm.add_weighted_edges_from(input_edges) - - expected_graph = DynamicStructureModel() - if not weighted: - expected_graph.add_edges_from(expected_edges) - else: - expected_graph.add_weighted_edges_from(expected_edges) - - assert set(sm.nodes) == set(expected_graph.nodes) - assert set(sm.edges) == set(expected_graph.edges) - - def test_node_not_tuple(self): - edges = [((1, 0), (3, 0), 0.5), ((1, 0), 3, 0.7)] - sm = 
DynamicStructureModel() - - with pytest.raises( - TypeError, - match=re.escape( - f"Nodes in {edges[1]} must be tuples with node name and time step" - ), - ): - sm.add_edges_from(edges) - - @pytest.mark.parametrize( - "input_edges, expected_edges", - [ - ( - [ - (DynamicStructureNode(0, 0), (1, 0), 0.5), - ((1, 0), DynamicStructureNode(2, 0), 0.7), - ], - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0), 0.7), - ], - ), - ( - [ - (DynamicStructureNode(0, 0), (1, 0)), - ((1, 0), DynamicStructureNode(2, 0)), - ], - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - ], - ), - ( - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - ((1, 0), DynamicStructureNode(2, 0)), - ], - [ - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(1, 0), DynamicStructureNode(2, 0)), - ], - ), - ], - ) - def test_multi_edge_one_dsn(self, input_edges, expected_edges): - sm = DynamicStructureModel() - weighted = len(input_edges[0]) == 3 - if not weighted: - sm.add_edges_from(input_edges) - else: - sm.add_weighted_edges_from(input_edges) - - expected_graph = DynamicStructureModel() - if not weighted: - expected_graph.add_edges_from(expected_edges) - else: - expected_graph.add_weighted_edges_from(expected_edges) - - assert set(sm.nodes) == set(expected_graph.nodes) - assert set(sm.edges) == set(expected_graph.edges) - - def test_multi_edge_bad_tuple(self): - edges = [((0, 0), (1, 0), 0.5), ((1, 0), (2, 0), 0.7, 0.8)] - sm = DynamicStructureModel() - with pytest.raises( - TypeError, - match=re.escape( - f"Argument {edges[1]} must be a tuple containing 2 or 3 elements" - ), - ): - sm.add_weighted_edges_from(edges) - - # def test_single_edge_not_tuple(self): - # edge = 6 - # sm = DynamicStructureModel() - - # with pytest.raises( - # TypeError, - # match=re.escape(f"Edges must be tuples containing 2 or 3 
elements, received {edge}"), - # ): - # sm.add_edge(edge) - - def test_single_edge_node_not_tuple(self): - u = (1, 0) - v = 3 - sm = DynamicStructureModel() - - with pytest.raises( - TypeError, - match=re.escape( - f"Nodes in {(u, v)} must be tuples with node name and time step" - ), - ): - sm.add_edge(u, v) - - @pytest.mark.parametrize( - "input_edge, expected_edge", - [ - ( - ((0, 0), (1, 0), 0.5), - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), - ), - ( - ((0, 0), (1, 0)), - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - ), - ( - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - ), - ( - (DynamicStructureNode(0, 0), (1, 0)), - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - ), - ( - ((0, 0), DynamicStructureNode(1, 0)), - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)), - ), - ( - (DynamicStructureNode(0, 0), (1, 0), 0.5), - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), - ), - ( - ((0, 0), DynamicStructureNode(1, 0), 0.5), - (DynamicStructureNode(0, 0), DynamicStructureNode(1, 0), 0.5), - ), - ], - ) - def test_single_edge_dsn(self, input_edge, expected_edge): - sm = DynamicStructureModel() - weighted = len(input_edge) == 3 - if not weighted: - sm.add_edge(input_edge[0], input_edge[1]) - else: - sm.add_weighted_edges_from(input_edge) - - expected_graph = DynamicStructureModel() - if not weighted: - expected_graph.add_edge(expected_edge[0], expected_edge[1]) - else: - expected_graph.add_weighted_edges_from(expected_edge) - - assert set(sm.nodes) == set(expected_graph.nodes) - assert set(sm.edges) == set(expected_graph.edges) - - def test_single_edge_bad_tuple(self): - edge = ((1, 0), (2, 0), 0.7, 0.8) - sm = DynamicStructureModel() - with pytest.raises( - TypeError, - match=re.escape( - f"Argument {edge} must be either a DynamicStructureNode or tuple containing 2 or 3 elements" - ), - ): - sm.add_weighted_edges_from(edge) - - 
-class TestDynamicStructureModelNodeCoercion: - @pytest.mark.parametrize( - "input_nodes, expected_nodes", - [ - ( - [(0, 0), (1, 0)], - [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)], - ), - ( - [DynamicStructureNode(0, 0), (1, 0)], - [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)], - ), - ( - (DynamicStructureNode(n, 0) for n in range(2)), - [DynamicStructureNode(0, 0), DynamicStructureNode(1, 0)], - ), - ], - ) - def test_multi_node(self, input_nodes, expected_nodes): - sm = DynamicStructureModel() - sm.add_nodes(input_nodes) - - expected_graph = DynamicStructureModel() - expected_graph.add_nodes(expected_nodes) - - assert set(sm.nodes) == set(expected_graph.nodes) - - @pytest.mark.parametrize( - "input_node, expected_node", - [ - ((0, 0), DynamicStructureNode(0, 0)), - (DynamicStructureNode(0, 0), DynamicStructureNode(0, 0)), - ], - ) - def test_single_node(self, input_node, expected_node): - sm = DynamicStructureModel() - sm.add_nodes(input_node) - - expected_graph = DynamicStructureModel() - expected_graph.add_nodes(expected_node) - - assert set(sm.nodes) == set(expected_graph.nodes) - - def test_multi_node_bad_tuple(self): - nodes = [(0, 0), (1, 0, 1)] - sm = DynamicStructureModel() - with pytest.raises( - TypeError, - match=re.escape( - f"Argument {nodes[1]} must be either a DynamicStructureNode or tuple containing 2 elements" - ), - ): - sm.add_nodes(nodes) - - def test_single_node_bad_tuple(self): - node = (1, 0, 1) - sm = DynamicStructureModel() - with pytest.raises( - TypeError, - match=re.escape( - f"Argument {node} must be either a DynamicStructureNode or tuple containing 2 elements" - ), - ): - sm.add_node(node) From acd9dfe1e92bfd5a59ffd2535c7344e2a51075a2 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Mon, 17 Oct 2022 20:51:31 -0500 Subject: [PATCH 10/13] fix linter error --- tests/structure/test_dsn_subgraph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/structure/test_dsn_subgraph.py 
b/tests/structure/test_dsn_subgraph.py index aee3e08..ef3d9fc 100644 --- a/tests/structure/test_dsn_subgraph.py +++ b/tests/structure/test_dsn_subgraph.py @@ -33,6 +33,7 @@ from causalnex.structure import DynamicStructureModel, DynamicStructureNode + class TestDynamicStructureModelGetLargestSubgraph: @pytest.mark.parametrize( "test_input, expected", From 8fcf35543a2d91fffcdb1c6293c2deb1b2207ef3 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Mon, 17 Oct 2022 21:12:06 -0500 Subject: [PATCH 11/13] test doc fix --- causalnex/structure/structuremodel.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index e473c77..7151c85 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -354,6 +354,9 @@ def get_node_name(self): return f"{self.node}_lag{self.time_step}" +DynamicStructureNode.__new__.__module__ = __name__ + + def check_collection_type(c): """ Check if data structure is a collection From 2143e526f6fd4fd30d8dd254c4d9f142c5f01c1f Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Mon, 17 Oct 2022 21:24:28 -0500 Subject: [PATCH 12/13] add comment --- causalnex/structure/structuremodel.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index 7151c85..19bf4fc 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -354,6 +354,11 @@ def get_node_name(self): return f"{self.node}_lag{self.time_step}" +""" +There is a problem with `sphinx-autodoc-typehints` and NamedTuples. 
This is needed +to prevent https://github.com/tox-dev/sphinx-autodoc-typehints/issues/68 +Fix here https://github.com/sphinx-doc/sphinx/issues/6636#issuecomment-608083353 +""" DynamicStructureNode.__new__.__module__ = __name__ From 5c8505b77a1214b1603153888ec718e54ee67190 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Mon, 17 Oct 2022 21:38:00 -0500 Subject: [PATCH 13/13] linter fix --- causalnex/structure/structuremodel.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py index 19bf4fc..10460d6 100644 --- a/causalnex/structure/structuremodel.py +++ b/causalnex/structure/structuremodel.py @@ -354,11 +354,9 @@ def get_node_name(self): return f"{self.node}_lag{self.time_step}" -""" -There is a problem with `sphinx-autodoc-typehints` and NamedTuples. This is needed -to prevent https://github.com/tox-dev/sphinx-autodoc-typehints/issues/68 -Fix here https://github.com/sphinx-doc/sphinx/issues/6636#issuecomment-608083353 -""" +# There is a problem with `sphinx-autodoc-typehints` and NamedTuples. This is needed +# to prevent https://github.com/tox-dev/sphinx-autodoc-typehints/issues/68 +# Fix here https://github.com/sphinx-doc/sphinx/issues/6636#issuecomment-608083353 DynamicStructureNode.__new__.__module__ = __name__