-
Notifications
You must be signed in to change notification settings - Fork 0
/
sankey_wrapper_plotly.py
64 lines (49 loc) · 1.8 KB
/
sankey_wrapper_plotly.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
def add_link(source, target, value, label=""):
    """Return one Sankey link record as a dict.

    The record has the keys ``"source"``, ``"target"``, ``"value"`` and
    ``"label"``. These are the columns ``plot_sankey`` reads from its link
    table, so a list of these records can be turned into that table with
    ``pd.DataFrame(records)``.

    Args:
        source: Label of the node the flow starts at.
        target: Label of the node the flow ends at.
        value: Magnitude of the flow.
        label: Optional text attached to the link (defaults to "").

    Returns:
        dict: The link record.
    """
    # Store the plain values; a trailing comma after an assignment would
    # silently wrap them in 1-tuples.
    return {"source": source, "target": target, "value": value, "label": label}
def get_nodes(df_links, colormap="Spectral"):
    """Build the node table for a Sankey diagram from a link table.

    Args:
        df_links: DataFrame with ``"source"`` and ``"target"`` columns that
            hold node labels.
        colormap: Name of the matplotlib colormap used to colour the nodes.

    Returns:
        pandas.DataFrame: One row per distinct node label, with columns
        ``"label"``, ``"color"`` (the colormap's output array for that
        node) and ``"index"`` (the row position, used as the Plotly node id).
        Labels appear in first-seen order: all sources, then the targets.
    """
    endpoints = pd.concat([df_links["source"], df_links["target"]])
    labels = endpoints.unique()

    # Sample the colormap at evenly spaced points so each node gets its
    # own colour.
    palette = plt.get_cmap(colormap)(np.linspace(0, 1, len(labels)))

    table = pd.DataFrame()
    table["label"] = labels
    table["color"] = list(palette)
    table["index"] = table.index
    return table
def plot_sankey(df_links, title="Example", filename="sankey.html", auto_open=True):
    """Render a Plotly Sankey diagram from a link table and write it to HTML.

    Args:
        df_links: DataFrame with columns ``"source"``, ``"target"``,
            ``"value"`` and ``"label"``. Source/target hold node labels.
            The caller's frame is not modified.
        title: Figure title (defaults to "Example").
        filename: Path of the HTML file to write (defaults to "sankey.html").
        auto_open: Passed to ``write_html``; whether to open the file after
            writing (defaults to True).

    Returns:
        plotly.graph_objects.Figure: The figure that was written.
    """
    df_nodes = get_nodes(df_links)

    # Plotly refers to nodes by integer index, so translate the source/target
    # labels into node indices. Only these two columns are mapped. A
    # frame-wide ``replace`` would also rewrite link labels (or values) that
    # happen to equal a node name.
    enc_label = dict(zip(df_nodes["label"], df_nodes["index"]))
    df_links = df_links.copy()
    df_links["source"] = df_links["source"].map(enc_label)
    df_links["target"] = df_links["target"].map(enc_label)

    # Each link is drawn in its target node's colour at 0.7 opacity.
    link_colors = {
        idx: to_rgba_str(color, 0.7)
        for idx, color in zip(df_nodes["index"], df_nodes["color"])
    }
    df_links["color"] = df_links["target"].map(link_colors)
    df_nodes["color"] = df_nodes["color"].apply(to_rgba_str)

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=25,
            thickness=20,
            label=df_nodes["label"],
            color=df_nodes["color"],
        ),
        link=dict(
            source=df_links["source"],
            target=df_links["target"],
            value=df_links["value"],
            label=df_links["label"],
            color=df_links["color"],
        ),
    )])
    fig.update_layout(title_text=title, font_size=10)
    fig.write_html(filename, auto_open=auto_open)
    return fig
def to_rgba_str(rgb_arr, alpha=None):
    """Format a matplotlib-style colour as a CSS ``rgba(...)`` string.

    Args:
        rgb_arr: Sequence of 3 (RGB) or 4 (RGBA) floats in ``[0, 1]``, such
            as a row produced by a matplotlib colormap.
        alpha: Opacity in ``[0, 1]`` that overrides the colour's own. When
            ``None``, the last component is used for RGBA input, and ``1.0``
            (opaque) for plain RGB input.

    Returns:
        str: For example ``"rgba(255, 0, 51, 0.7)"``. The red, green and blue
        channels are scaled to the 0-255 integer range that CSS colour
        strings (and therefore Plotly) expect.
    """
    if alpha is None:
        # Only 4-component input carries its own alpha; using the last
        # element of an RGB triple would wrongly treat blue as opacity.
        alpha = rgb_arr[-1] if len(rgb_arr) > 3 else 1.0
    # Left as 0-1 floats, the channels would be read as near-zero 0-255
    # values and render almost black.
    red, green, blue = (int(round(float(c) * 255)) for c in rgb_arr[:3])
    return f"rgba({red}, {green}, {blue}, {alpha})"