Skip to content

Commit

Permalink
initial implementation of causal graphs #68
Browse files Browse the repository at this point in the history
  • Loading branch information
stevenlujpl committed Oct 8, 2021
1 parent 70c6add commit ea21896
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 15 deletions.
145 changes: 133 additions & 12 deletions dora_exp_pipeline/dora_results_organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@
# June 3, 2021

import os
import sys
from six import add_metaclass
from abc import ABCMeta, abstractmethod
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
from sklearn.cluster import KMeans
from sklearn_som.som import SOM
sys.path.append("/Users/youlu/Desktop/dora/work/causal_graph/fges-py")
import SEMScore
import fges
import knowledge
import networkx as nx


METHOD_POOL = []
Expand Down Expand Up @@ -52,14 +58,14 @@ def can_run(self, loader_name):
else:
return False

def run(self, data_ids, dts_scores, dts_sels, data_to_score,
def run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n, **params):
self._run(data_ids, dts_scores, dts_sels, data_to_score,
self._run(data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n,
**params)

@abstractmethod
def _run(self, data_ids, dts_scores, dts_sels, data_to_score,
def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, logger, seed, top_n, **params):
raise RuntimeError('This function must be implemented in a child class')

Expand All @@ -68,7 +74,7 @@ class SaveScoresCSV(ResultsOrganization):
def __init__(self):
super(SaveScoresCSV, self).__init__('save_scores')

def _run(self, data_ids, dts_scores, dts_sels, data_to_score,
def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n):
if not os.path.exists(out_dir):
os.mkdir(out_dir)
Expand All @@ -92,7 +98,7 @@ class SaveComparisonPlot(ResultsOrganization):
def __init__(self):
super(SaveComparisonPlot, self).__init__('comparison_plot')

def _run(self, data_ids, dts_scores, dts_sels, data_to_score,
def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n, validation_dir):
if(not(os.path.exists(out_dir))):
os.makedirs(out_dir)
Expand Down Expand Up @@ -142,8 +148,9 @@ class KmeansCluster(ResultsOrganization):
def __init__(self):
super(KmeansCluster, self).__init__('kmeans')

def _run(self, data_ids, dts_scores, dts_sels, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n, n_clusters):
def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n, n_clusters,
causal_graph):
if not os.path.exists(out_dir):
os.mkdir(out_dir)
if logger:
Expand All @@ -169,6 +176,10 @@ def _run(self, data_ids, dts_scores, dts_sels, data_to_score,

out_file.close()

if causal_graph:
generate_causal_graphs(data_to_fit, data_to_cluster, groups, out_dir,
logger, seed)


kmeans_cluster = KmeansCluster()
register_org_method(kmeans_cluster)
Expand All @@ -178,8 +189,9 @@ class SOMCluster(ResultsOrganization):
def __init__(self):
super(SOMCluster, self).__init__('som')

def _run(self, data_ids, dts_scores, dts_sels, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n, n_clusters):
def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n, n_clusters,
causal_graph):
if not os.path.exists(out_dir):
os.mkdir(out_dir)
if logger:
Expand All @@ -202,16 +214,125 @@ def _run(self, data_ids, dts_scores, dts_sels, data_to_score,

out_file.close()

if causal_graph:
generate_causal_graphs(data_to_fit, data_to_cluster, groups, out_dir,
logger, seed)


som_cluster = SOMCluster()
register_org_method(som_cluster)


class NodeBlock:
"""
members is a list of node names
order is a single integer. Low-->high is left-->right
"""
def __init__(self, members=None, order=None):
self.members = members
self.order = order


def generate_causal_graphs(data_to_fit, data_to_cluster, cluster_groups, out_dir,
logger, seed):
causal_tags = ['feature-%d' % col for col in range(len(data_to_cluster[0]))]
causal_tags = causal_tags + ['outlier']

if len(causal_tags) > 20:
if logger:
logger.text('Can not generate causal graphs for data sets with '
'more than 20 features.')

unique_groups = np.unique(cluster_groups)
data_to_fit = np.append(data_to_fit, np.zeros((len(data_to_fit), 1)), axis=1)
for group_label in unique_groups:
in_group = cluster_groups == group_label
outliers = data_to_cluster[in_group]
outliers = np.append(outliers, np.ones((len(outliers), 1)), axis=1)
data = np.vstack((data_to_fit, outliers))

# Define knowledge that forbids any connections from features to outlier
ken = knowledge.Knowledge()
block1 = ['feature-%d' % col for col in range(len(data_to_cluster[0]))]
block2 = ['outlier']
for i, i_label in enumerate(causal_tags):
for j, j_label in enumerate(causal_tags):
if (i_label in block1) & (j_label in block2):
ken.set_forbidden(i, j)

# Generate causal graphs
variables = list(range(len(data[0])))
score = SEMScore.SEMBicScore(2, dataset=data)
cs = fges.FGES(variables, score, knowledge=ken)
cs.search()
graph = cs.graph

# Assign names to graph nodes
node_labels = dict()
for idx, tag in enumerate(causal_tags):
node_labels.update({idx: tag})
graph = nx.relabel_nodes(graph, node_labels)

# Save graph
out_file = '%s/causal_graph_cluster_%d' % (out_dir, group_label)

block1_nodes = NodeBlock(members=causal_tags[:-1], order=2)
block2_nodes = NodeBlock(members=['outlier'], order=1)
pos = arrange_nodes([block1_nodes, block2_nodes])
outlier_pos = pos['outlier']
outlier_pos = (outlier_pos[0], -10)
pos['outlier'] = outlier_pos
nx.draw(graph, with_labels=True, pos=pos)
x_values, y_values = zip(*pos.values())
xmin, xmax = min(x_values), max(x_values)
xmargin = 0.25 * (xmax - xmin)
plt.xlim(xmin - xmargin, xmax + xmargin)

plt.savefig(out_file)
plt.clf()


def arrange_nodes(blocks, labelheight=2, colwidth=20, blocksep=20,
bottom_margin=20):
"""
takes a list-like of 'blocks' objects.
arranges nodes into a hopefully-pleasing shape according to block membership.
Note that this expects each node belongs to only a single block. No
guarantees if that's not true.
"""
# Let's assume that the node labels are arranged into blocks.
# Within a block, they need to spaced widely enough.
# Let's also assume that each block has a number indicating its order of precedence.
block_orders = np.array([block.order for block in blocks])
orders = np.sort(np.unique(block_orders.copy()))
pos = {}
labels = []
for iorder, thisorder in enumerate(orders):
these = np.arange(len(block_orders))[block_orders == thisorder]
left_position = iorder * colwidth
order_heights = []
for iblock, blockindex in enumerate(these):
for imember, member in enumerate(blocks[blockindex].members):
height = iblock*blocksep + imember*2*labelheight + bottom_margin
pos.update({member: (left_position, height)})
labels.append(member)
order_heights.append(height)
meanheight = np.mean(order_heights)
for iblock, blockindex in enumerate(these):
for imember, member in enumerate(blocks[blockindex].members):
height = iblock*blocksep + imember*2*labelheight + bottom_margin
pos.update({member: (left_position, height - meanheight)})
labels.append(member)
order_heights.append(height)

return pos


class ReshapeRaster(ResultsOrganization):
def __init__(self):
super(ReshapeRaster, self).__init__('reshape_raster')

def _run(self, data_ids, dts_scores, dts_sels, data_to_score,
def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n,
raster_path, data_format, patch_size, colormap):
if not os.path.exists(out_dir):
Expand Down Expand Up @@ -273,8 +394,8 @@ class SaveHistogram(ResultsOrganization):
def __init__(self):
super(SaveHistogram, self).__init__('histogram')

def _run(self, data_ids, dts_scores, dts_sels, data_to_score, alg_name,
out_dir, logger, seed, bins):
def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
alg_name, out_dir, logger, seed, bins):
if(not(os.path.exists(out_dir))):
os.makedirs(out_dir)

Expand Down
5 changes: 3 additions & 2 deletions dora_exp_pipeline/outlier_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ def run(self, dtf: np.ndarray, dts: np.ndarray, dts_ids: list, out_dir: str,
for res_org_name, res_org_params in results_org_dict.items():
res_org_method = get_res_org_method(res_org_name)
res_org_method.run(results['dts_ids'], results['scores'],
results['sel_ind'], dts, self._ranking_alg_name,
sub_dir, logger, seed, top_n, **res_org_params)
results['sel_ind'], dtf, dts,
self._ranking_alg_name, sub_dir, logger, seed,
top_n, **res_org_params)

@staticmethod
def dict_to_str(params_dict: dict()) -> str:
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
'tensorflow==2.5.1',
'tensorflow-probability==0.13.0',
'tensorflow-addons==0.13.0',
'sklearn-som==1.1.0'
'sklearn-som==1.1.0',
'sortedcontainers==2.4.0',
'dill==0.3.4'
],
provide=[
'dora_exp_pipeline'
Expand Down

0 comments on commit ea21896

Please sign in to comment.