Skip to content

Commit

Permalink
(1) add sparsity parameter; (2) improve graph visualization #68
Browse files Browse the repository at this point in the history
  • Loading branch information
stevenlujpl committed Oct 18, 2021
1 parent 48ea524 commit 0f2afe8
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions dora_exp_pipeline/dora_results_organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(self):

def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n, n_clusters,
causal_graph):
causal_graph, sparsity):
if not os.path.exists(out_dir):
os.mkdir(out_dir)
if logger:
Expand Down Expand Up @@ -177,7 +177,7 @@ def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,

if causal_graph:
generate_causal_graphs(data_to_fit, data_to_cluster, groups,
out_dir, logger, seed)
out_dir, logger, seed, sparsity)


kmeans_cluster = KmeansCluster()
Expand All @@ -190,7 +190,7 @@ def __init__(self):

def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
outlier_alg_name, out_dir, logger, seed, top_n, n_clusters,
causal_graph):
causal_graph, sparsity):
if not os.path.exists(out_dir):
os.mkdir(out_dir)
if logger:
Expand All @@ -203,7 +203,8 @@ def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,
data_to_cluster.append(data_to_score[dts_ind])
data_to_cluster = np.array(data_to_cluster, dtype=float)

som = SOM(m=n_clusters, n=1, dim=len(data_to_cluster[0]))
som = SOM(m=n_clusters, n=1, dim=len(data_to_cluster[0]),
random_state=seed)
som.fit(data_to_cluster)
groups = som.predict(data_to_cluster)

Expand All @@ -215,16 +216,16 @@ def _run(self, data_ids, dts_scores, dts_sels, data_to_fit, data_to_score,

if causal_graph:
generate_causal_graphs(data_to_fit, data_to_cluster, groups,
out_dir, logger, seed)
out_dir, logger, seed, sparsity)


som_cluster = SOMCluster()
register_org_method(som_cluster)


def generate_causal_graphs(data_to_fit, data_to_cluster, cluster_groups,
out_dir, logger, seed):
causal_tags = ['feature-%d' % col for col in range(len(data_to_cluster[0]))]
out_dir, logger, seed, sparsity):
causal_tags = ['feat-%d' % col for col in range(len(data_to_cluster[0]))]
causal_tags = causal_tags + ['cluster']

if len(causal_tags) > 20:
Expand All @@ -243,7 +244,7 @@ def generate_causal_graphs(data_to_fit, data_to_cluster, cluster_groups,

# Define knowledge that forbids any connections from features to cluster
ken = knowledge.Knowledge()
block1 = ['feature-%d' % col for col in range(len(data_to_cluster[0]))]
block1 = ['feat-%d' % col for col in range(len(data_to_cluster[0]))]
block2 = ['cluster']
for i, i_label in enumerate(causal_tags):
for j, j_label in enumerate(causal_tags):
Expand All @@ -252,21 +253,34 @@ def generate_causal_graphs(data_to_fit, data_to_cluster, cluster_groups,

# Generate causal graphs
variables = list(range(len(data[0])))
score = SEMScore.SEMBicScore(2, dataset=data)
score = SEMScore.SEMBicScore(sparsity, dataset=data)
cs = fges.FGES(variables, score, knowledge=ken)
cs.search()
graph = cs.graph

# Assign names to graph nodes
node_labels = dict()
for idx, tag in enumerate(causal_tags):
if tag == 'cluster':
tag = 'cluster-%d' % group_label

node_labels.update({idx: tag})
graph = nx.relabel_nodes(graph, node_labels)

# Save graph
out_file = '%s/causal_graph_cluster_%d' % (out_dir, group_label)
pos = nx.circular_layout(graph, scale=2, dim=2)
nx.draw(graph, with_labels=True, pos=pos)
for edge in graph.edges():
if 'cluster-%d' % group_label in edge:
graph[edge[0]][edge[1]]['color'] = 'red'
else:
graph[edge[0]][edge[1]]['color'] = 'blue'

colors = [graph[u][v]['color'] for u, v in graph.edges()]

plt.figure(figsize=(12, 8))
nx.draw(graph, with_labels=True, edge_color=colors, pos=pos,
node_size=3000, node_color='lightgreen')

plt.savefig(out_file)
plt.clf()
Expand Down

0 comments on commit 0f2afe8

Please sign in to comment.