Skip to content

Commit

Permalink
Release/0.7.0 (mckinsey#57)
Browse files Browse the repository at this point in the history
* Added plotting tutorial to the documentation
* Updated `viz.draw` syntax in tutorial notebooks
* Bugfix on notears lasso (`from_numpy_lasso` and `from_pandas_lasso`) where the non-negativity constraint was not being set
* Added DAG-based synthetic data generator for mixed types (binary, categorical, continuous) using a linear SEM approach.
* Unpinned some requirements
  • Loading branch information
qbphilip authored May 28, 2020
1 parent f878706 commit 595907c
Show file tree
Hide file tree
Showing 26 changed files with 3,930 additions and 48 deletions.
11 changes: 11 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Upcoming release


# Release 0.7.0

* Added plotting tutorial to the documentation
* Updated `viz.draw` syntax in tutorial notebooks
* Bugfix on notears lasso (`from_numpy_lasso` and `from_pandas_lasso`) where the non-negativity constraint was not being set
* Added DAG-based synthetic data generator for mixed types (binary, categorical, continuous) using a linear SEM approach.
* Unpinned some requirements

# Release 0.6.0

* support for newer versions of scikit-learn
Expand Down
2 changes: 1 addition & 1 deletion causalnex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
causalnex toolkit for causal reasoning (Bayesian Networks / Inference)
"""

__version__ = "0.6.0"
__version__ = "0.7.0"

__all__ = ["structure", "discretiser", "evaluation", "inference", "network", "plots"]
9 changes: 6 additions & 3 deletions causalnex/network/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"""

import re
from copy import deepcopy
from typing import Dict, Hashable, List, Set, Tuple

import networkx as nx
Expand Down Expand Up @@ -474,11 +475,13 @@ def _predict_from_incomplete_data(
A dataframe of predictions, containing a single column name {node}_prediction.
"""

transformed_data = data.copy(deep=True) # type: pd.DataFrame
transformed_data = deepcopy(data) # type: pd.DataFrame
self._state_to_index(transformed_data)

# transformed_data.is_copy()

# pgmpy will predict all missing data, so drop column we want to predict
transformed_data.drop(node, axis=1, inplace=True)
transformed_data = transformed_data.drop(columns=[node])

predictions = self._model.predict(transformed_data)[[node]]

Expand Down Expand Up @@ -560,7 +563,7 @@ def _predict_probability_from_incomplete_data(
self._state_to_index(transformed_data)

# pgmpy will predict all missing data, so drop column we want to predict
transformed_data.drop(node, axis=1, inplace=True)
transformed_data = transformed_data.drop(columns=[node])

probability = self._model.predict_probability(
transformed_data
Expand Down
320 changes: 320 additions & 0 deletions causalnex/structure/categorical_variable_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
# Copyright 2019-2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utility functions to handle categorical features. While the statistical
dependencies are defined on the node level, categorical variables require an
expansion to "one-hot" encoding for numeric analysis.
"""
import re
from collections import OrderedDict
from itertools import chain
from typing import Dict, Hashable, Iterator, List, Optional, Set, Union

from networkx.classes.reportviews import NodeView


class VariableFeatureMapper:
    """
    When expanding the columns from variables to features, this class keeps
    track of the 1:m mapping between the collapsed and expanded columns.

    Binary and continuous variables map to a single feature that carries the
    variable itself as its name. A categorical variable declared as
    ``categorical:C`` is expanded into ``C`` "one-hot" features named
    ``"{variable}_{i}"`` for ``i`` in ``range(C)``.

    Features are indexed in the order: binary, continuous, expanded
    categorical features.

    Args:
        schema: A dictionary mapping a variable (can be any hashable object) to
            a data type. Permissible data types are stored in the class
            attribute ``PERMISSIBLE_TYPES``. A data type is assigned by
            substring match, and categorical types must carry their
            cardinality after a colon (e.g. ``categorical:3``).
    """

    PERMISSIBLE_TYPES = {"binary", "categorical", "continuous"}
    EXPANDABLE_TYPE = "categorical"

    def __init__(self, schema: Dict[Hashable, str]):
        # 0. get all variables, grouped by data type.
        # ``PERMISSIBLE_TYPES`` is a set of strings, so its iteration order is
        # not stable between interpreter runs; sorting makes the order of
        # ``variable_type_dict`` (and hence ``variable_list``) reproducible.
        self.variable_type_dict = OrderedDict(
            [
                (x, [k for k, v in schema.items() if x in v])
                for x in sorted(self.PERMISSIBLE_TYPES)
            ]
        )

        # 1. split categorical variable into features
        # dictionary: categorical variable to cardinality
        cat_card_dict = OrderedDict(
            [
                (cat_var, int(v.split(":")[1]))
                for cat_var, v in schema.items()
                if self.EXPANDABLE_TYPE in v
            ]
        )

        # dictionary: categorical feature to variable (C:1 mapping)
        self._cat_fte_var_dict = OrderedDict(
            [
                ("{}_{}".format(cat_var, i), cat_var)
                for cat_var, card in cat_card_dict.items()
                for i in range(card)
            ]
        )
        cat_feature_list = list(self._cat_fte_var_dict.keys())

        # we put them together with the continuous + binary in a feature list
        self.feature_list = (
            self.variable_type_dict["binary"]
            + self.variable_type_dict["continuous"]
            + cat_feature_list
        )

        # 2. we assign an index to each feature
        # dictionary: feature to index
        self._fte_index_dict = {fte: ix for ix, fte in enumerate(self.feature_list)}

        # 3. map a variable to all corresponding (expanded) columns
        # dictionary: variable to indices of all corresponding features
        self.var_indices_dict = {
            var: [self._fte_index_dict[var]]
            for var in self.variable_type_dict["continuous"]
            + self.variable_type_dict["binary"]
        }
        self.var_indices_dict.update(
            {
                k: [
                    self._fte_index_dict[fte]
                    for fte, var in self._cat_fte_var_dict.items()
                    if var == k
                ]
                for k in self.variable_type_dict["categorical"]
            }
        )

    @property
    def variable_list(self) -> List[Hashable]:
        """
        Returns a list of all variables/nodes, grouped by data type in the
        order of ``variable_type_dict``.
        """
        return list(chain.from_iterable(self.variable_type_dict.values()))

    def get_var_of_type(self, data_type: str) -> List[Hashable]:
        """
        Returns all variables/nodes corresponding to the provided data type

        Args:
            data_type: Variable type.

        Returns:
            List of variables

        Raises:
            ValueError: if the variable type is not supported
        """
        if data_type not in self.PERMISSIBLE_TYPES:
            raise ValueError(
                "Unsupported variable type {}, ".format(data_type)
                + "supported data types are: {}".format(
                    # sorted so that the message is the same on every run
                    ", ".join(sorted(self.PERMISSIBLE_TYPES))
                )
            )
        return self.variable_type_dict[data_type]

    def is_var_of_type(self, var: Hashable, data_type: str) -> bool:
        """
        Checks whether the variable/node is of the provided data type

        Args:
            var: Variable/node
            data_type: Supported data type

        Returns:
            Boolean flag

        Raises:
            ValueError: if the variable type is not supported
        """
        return var in self.get_var_of_type(data_type=data_type)

    def get_categorical_indices(self) -> List[List[int]]:
        """
        Returns a list of lists that includes all categorical feature indices
        for all categorical variables.
        """
        return [
            self.var_indices_dict[var] for var in self.variable_type_dict["categorical"]
        ]

    def get_indices(
        self,
        var: Union[Hashable, List[Hashable], Set[Hashable], Iterator],
        squeeze: bool = False,
    ) -> Union[int, List[int]]:
        """
        Returns the indices for a variable or list of variables.

        Args:
            var: A variable/node, or a list/set/iterator of variables/nodes.
            squeeze: No effect if either a list,set,"dict_iterable" (e.g. from
                ``graph.predecessor(var)``) is provided or var is a
                categorical value.

        Returns:
            A list of indices. For binary and continuous variables this will be
            a list of length one. If squeeze is True, returns the index outside
            of a list. For a collection of variables, the indices of all
            members are concatenated into one flat list.

        Raises:
            ValueError: if an unsupported variable object is provided.
        """
        # A known variable takes precedence over the collection handling below.
        if var in self.variable_list:
            if squeeze and var not in self.get_var_of_type("categorical"):
                return self.var_indices_dict[var][0]
            return self.var_indices_dict[var]
        if isinstance(var, (list, set)):
            return [ix for v in var for ix in self.get_indices(v)]
        if hasattr(var, "__next__"):
            # Deals with DiGraph.predecessors's "dict_keyiterator" and similar
            # Iterators. Note that the iterator is consumed here.
            return [ix for v in list(var) for ix in self.get_indices(v)]
        raise ValueError(
            "Provide a valid variable name, a set/list/Iterator of variable "
            "names. Other iterables are not supported."
        )

    def get_feature_index(self, feature: Hashable) -> int:
        """
        Gets the feature index.

        Args:
            feature: Name of a feature, i.e. an entry of ``feature_list``.

        Returns:
            The index of a feature.

        Raises:
            ValueError: if a categorical variable instead of a categorical
                "one-hot" feature is provided.
            KeyError: if the name is neither a variable nor a feature.
        """
        if feature in self.variable_list and feature not in self.feature_list:
            raise ValueError(
                "Input is not a feature, use ``get_indices`` to get the indices "
                "associated for a variable/node."
            )
        return self._fte_index_dict[feature]

    def get_feature_names(
        self, var: Optional[Hashable] = None
    ) -> Union[Hashable, List[Hashable]]:
        """
        Get the feature name(s) corresponding to the variable. If none provided,
        returns all features.

        Args:
            var: A variable/node, or None to request all features.

        Returns:
            Returns all feature names corresponding to a variable/node.
            - For binary and continuous variables, this is the variable/node
            itself.
            - For categorical variables, it returns a List of Hashables
        """
        if var is None:
            return self.feature_list
        # Anything that is not a categorical variable is handed back unchanged.
        if var not in self.variable_type_dict["categorical"]:
            return var
        return [k for k, v in self._cat_fte_var_dict.items() if v == var]

    @property
    def n_variables(self) -> int:
        """
        Returns:
            The number of variables
        """
        return len(self.variable_list)

    @property
    def n_features(self) -> int:
        """
        The number of features. If the schema only has binary and continuous
        variables, this is equal to ``n_variables``.
        """
        return len(self.feature_list)


def validate_schema(
    nodes: Union[List[Hashable], Set[Hashable], NodeView],
    default_type: str = "continuous",
    schema: Optional[Dict[Hashable, str]] = None,
) -> Dict:
    """
    Verifies category type and uses default data type for unspecified variables.
    Variables in the schema but not in the node list are ignored. The ``nodes``
    object is taken as the ground truth for variables to process.

    Args:
        nodes: All variables that should have a schema.
        schema: Dictionary mapping a variable to a data type.
        default_type: Allowed data types are 'binary', 'continuous',
            'categorical:X' where X stands for the cardinality of the category.
            Leading zeros are not allowed for the cardinality.

    Returns:
        Schema with missing type imputed by ``default_type``

    Raises:
        ValueError: for unknown data type
        ValueError: for missing or malformed cardinality for categorical
            variables
    """
    # Sorted so that error messages are identical on every run; iterating the
    # underlying set of strings directly gives a run-dependent order.
    supported_types = sorted(VariableFeatureMapper.PERMISSIBLE_TYPES)

    def _is_supported(data_type: str) -> bool:
        # Substring test (not equality) so that ``categorical:N`` is accepted;
        # this is the same test ``VariableFeatureMapper`` uses to assign types.
        return any(t in data_type for t in supported_types)

    if not _is_supported(default_type):
        raise ValueError(
            "Unknown default data type. Supported data types are {}".format(
                ", ".join(supported_types)
            )
        )

    schema = {} if schema is None else schema
    # add default data type to missing nodes, drop variables that are not nodes
    schema = {k: schema.get(k, default_type) for k in nodes}

    # verify if the data type is supported; report only the offending
    # variables (valid ``categorical:N`` entries must not be listed)
    unknown_type = [k for k, v in schema.items() if not _is_supported(v)]
    if unknown_type:
        raise ValueError(
            "Unknown data type for variable(s) {}, ".format(unknown_type)
            + "Supported data types are {}".format(", ".join(supported_types))
        )

    # ``fullmatch`` rejects trailing garbage such as ``categorical:3abc`` that
    # would otherwise only fail later when the cardinality is parsed as an int.
    missing_cardinality = {
        k
        for k, v in schema.items()
        if "categorical" in v
        and re.fullmatch(r"categorical:[1-9][0-9]*", v) is None
    }

    if missing_cardinality:
        raise ValueError(
            "Missing cardinality for categorical variable(s) {}".format(
                missing_cardinality
            )
            + " For example, specify the data type as `categorical:3` for a "
            + "3-class categorical feature. Leading zeros are not allowed."
        )
    return schema
Loading

0 comments on commit 595907c

Please sign in to comment.