Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/dbn #174

Draft
wants to merge 14 commits into
base: develop
Choose a base branch
from
4 changes: 3 additions & 1 deletion causalnex/structure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
"data_generators",
"DAGRegressor",
"DAGClassifier",
"DynamicStructureModel",
"DynamicStructureNode",
]

from .pytorch.sklearn import DAGClassifier, DAGRegressor
from .structuremodel import StructureModel
from .structuremodel import DynamicStructureModel, DynamicStructureNode, StructureModel
44 changes: 17 additions & 27 deletions causalnex/structure/dynotears.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
"""

import warnings
from typing import Dict, List, Tuple, Union
from typing import List, Tuple, Union

import numpy as np
import pandas as pd
import scipy.linalg as slin
import scipy.optimize as sopt

from causalnex.structure import StructureModel
from causalnex.structure import DynamicStructureModel, DynamicStructureNode
from causalnex.structure.transformers import DynamicDataTransformer


Expand All @@ -53,7 +53,7 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments
tabu_edges: List[Tuple[int, int, int]] = None,
tabu_parent_nodes: List[int] = None,
tabu_child_nodes: List[int] = None,
) -> StructureModel:
) -> DynamicStructureModel:
liam-adams marked this conversation as resolved.
Show resolved Hide resolved
"""
Learn the graph structure of a Dynamic Bayesian Network describing conditional dependencies between variables in
data. The input data is a time series or a list of realisations of the same time series.
Expand Down Expand Up @@ -122,15 +122,19 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments
tabu_child_nodes,
)

sm = StructureModel()
sm.add_nodes_from(
[f"{var}_lag{l_val}" for var in col_idx.keys() for l_val in range(p + 1)]
sm = DynamicStructureModel()
sm.add_nodes(
[
DynamicStructureNode(var, l_val)
for var in col_idx.keys()
for l_val in range(p + 1)
]
)
sm.add_weighted_edges_from(
[
(
_format_name_from_pandas(idx_col, u),
_format_name_from_pandas(idx_col, v),
DynamicStructureNode(idx_col[int(u[0])], u[-1]),
DynamicStructureNode(idx_col[int(v[0])], v[-1]),
w,
)
for u, v, w in g.edges.data("weight")
Expand All @@ -141,20 +145,6 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments
return sm


def _format_name_from_pandas(idx_col: Dict[int, str], from_numpy_node: str) -> str:
"""
Helper function for `from_pandas_dynamic`. converts a node from the `from_numpy_dynamic` format to the `from_pandas`
format
Args:
idx_col: map from variable to intdex
from_numpy_node: nodes in the structure model output by `from_numpy_dynamic`.
Returns:
nodes in from_pandas_dynamic format
"""
idx, lag_val = from_numpy_node.split("_lag")
return f"{idx_col[int(idx)]}_lag{lag_val}"


def from_numpy_dynamic( # pylint: disable=too-many-arguments
X: np.ndarray,
Xlags: np.ndarray,
Expand All @@ -166,7 +156,7 @@ def from_numpy_dynamic( # pylint: disable=too-many-arguments
tabu_edges: List[Tuple[int, int, int]] = None,
tabu_parent_nodes: List[int] = None,
tabu_child_nodes: List[int] = None,
) -> StructureModel:
) -> DynamicStructureModel:
"""
Learn the graph structure of a Dynamic Bayesian Network describing conditional dependencies between variables in
data. The input data is time series data present in numpy arrays X and Xlags.
Expand Down Expand Up @@ -254,7 +244,7 @@ def from_numpy_dynamic( # pylint: disable=too-many-arguments

def _matrices_to_structure_model(
w_est: np.ndarray, a_est: np.ndarray
) -> StructureModel:
) -> DynamicStructureModel:
"""
Converts the matrices output by dynotears (W and A) into a StructureModel
We use the following convention:
Expand All @@ -268,13 +258,13 @@ def _matrices_to_structure_model(
StructureModel representing the structure learnt

"""
sm = StructureModel()
sm = DynamicStructureModel()
lag_cols = [
f"{var}_lag{l_val}"
DynamicStructureNode(var, l_val)
for l_val in range(1 + (a_est.shape[0] // a_est.shape[1]))
for var in range(a_est.shape[1])
]
sm.add_nodes_from(lag_cols)
sm.add_nodes(lag_cols)
sm.add_edges_from(
[
(lag_cols[i], lag_cols[j], dict(weight=w_est[i, j]))
Expand Down
Loading