Skip to content

Commit

Permalink
updated the tutorial of COMMOT and Flowsig
Browse files Browse the repository at this point in the history
  • Loading branch information
Starlitnightly committed Sep 25, 2024
1 parent 7833649 commit a2047cd
Show file tree
Hide file tree
Showing 7 changed files with 1,340 additions and 6 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ The table contains the tools have been published
<td align="center">COMPOSITE<br><a href="https://github.com/CHPGenetics/COMPOSITE/">📦</a> <a href="https://www.nature.com/articles/s41467-024-49448-x#Abs1">📖</a></td>
<td align="center">mellon<br><a href="https://github.com/settylab/mellon">📦</a> <a href="https://www.nature.com/articles/s41592-024-02302-w">📖</a></td>
<td align="center">starfysh<br><a href="https://github.com/azizilab/starfysh">📦</a> <a href="http://dx.doi.org/10.1038/s41587-024-02173-8">📖</a></td>
<td align="center">COMMOT<br><a href="https://github.com/zcang/COMMOT">📦</a> <a href="https://www.nature.com/articles/s41592-022-01728-4">📖</a></td>
<td align="center">flowsig<br><a href="https://github.com/axelalmet/flowsig">📦</a> <a href="https://doi.org/10.1038/s41592-024-02380-w">📖</a></td>
<td align="center">pyWGCNA<br><a href="https://github.com/mortazavilab/PyWGCNA">📦</a> <a href="https://doi.org/10.1093/bioinformatics/btad415">📖</a></td>
</tr>
</table>
</div>
Expand Down
13 changes: 10 additions & 3 deletions omicverse/externel/flowsig/tools/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def run_gsp(adata: ad.AnnData,
if use_spatial:

# Define the blocks for spatial block bootstrapping
block_clusters = sorted(adata.obs[block_key].unique().tolist())
#block_clusters = sorted(adata.obs[block_key].unique().tolist())
adata.obs[block_key]=adata.obs[block_key].astype('category')
block_clusters = adata.obs[block_key].cat.categories.tolist()


for block in block_clusters:
block_indices = np.where(adata.obs[block_key] == block)[0] # Sample only those cells within the block
Expand Down Expand Up @@ -118,7 +121,8 @@ def run_utigsp(adata: ad.AnnData,
control_resampled = control_samples.copy()

# Define the blocks for bootstrapping
block_clusters_control = sorted(adata_control.obs[block_key].unique().tolist())
adata_control.obs[block_key]=adata_control.obs[block_key].astype('category')
block_clusters_control = adata_control.obs[block_key].cat.categories.tolist()

for block in block_clusters_control:

Expand All @@ -133,7 +137,10 @@ def run_utigsp(adata: ad.AnnData,

adata_pert = adata_perturbed[adata_perturbed.obs[condition_key] == pert]

block_clusters_pert = sorted(adata_pert.obs[block_key].unique().tolist())
adata_pert.obs[block_key]=adata_pert.obs[block_key].astype('category')
block_clusters_pert = adata_pert.obs[block_key].cat.categories.tolist()

#block_clusters_pert = sorted(adata_pert.obs[block_key].unique().tolist())

for block in block_clusters_pert:

Expand Down
3 changes: 2 additions & 1 deletion omicverse/pl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
from ._multi import *
from ._bulk import *
from ._space import *
from ._cpdb import *
from ._cpdb import *
from ._flowsig import *
262 changes: 262 additions & 0 deletions omicverse/pl/_flowsig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
import numpy as np

import random
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.patches import FancyArrowPatch
def curved_line(x0, y0, x1, y1, eps=0.8, pointn=30):
import bezier
x2 = (x0 + x1) / 2.0 + 0.1 ** (eps + abs(x0 - x1)) * (-1) ** (random.randint(1, 4))
y2 = (y0 + y1) / 2.0 + 0.1 ** (eps + abs(y0 - y1)) * (-1) ** (random.randint(1, 4))
nodes = np.asfortranarray([
[x0, x2, x1],
[y0, y2, y1]
])
curve = bezier.Curve(nodes, degree=2)
s_vals = np.linspace(0.0, 1.0, pointn)
data = curve.evaluate_multi(s_vals)
x = data[0]
y = data[1]
return x, y

def curved_graph(_graph, pos=None, eps=0.2, pointn=30,
linewidth=2, alpha=0.3, color_dict=None):
ax = plt.gca()
for u, v in _graph.edges():
x0, y0 = pos[u]
x1, y1 = pos[v]
x, y = curved_line(x0, y0, x1, y1, eps=eps, pointn=pointn)

# 使用颜色字典为边着色
color = color_dict[u] if color_dict and u in color_dict else 'k'

# 绘制曲线部分
segments = np.array([x, y]).T.reshape(-1, 1, 2)
segments = np.concatenate([segments[:-1], segments[1:]], axis=1)
lc = LineCollection(segments, linewidths=linewidth, alpha=alpha, color=color)
ax.add_collection(lc)

# 在曲线的最后添加箭头
arrow = FancyArrowPatch((x[-2], y[-2]), (x[-1], y[-1]),
connectionstyle="arc3,rad=0.2",
arrowstyle='-|>', mutation_scale=10,
color=color, linewidth=linewidth, alpha=1)
ax.add_patch(arrow)

ax.autoscale_view()

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import random
import scanpy as sc
import networkx as nx
import pandas as pd
import anndata
def plot_curve_network(G: nx.Graph, G_type_dict: dict, G_color_dict: dict, pos_type: str = 'spring', pos_dim: int = 2,
figsize: tuple = (4, 4), pos_scale: int = 10, pos_k=None, pos_alpha: float = 0.4,
node_size: int = 50, node_alpha: float = 0.6, node_linewidths: int = 1,
plot_node=None, plot_node_num: int = 20, node_shape=None,
label_verticalalignment: str = 'center_baseline', label_fontsize: int = 12,
label_fontfamily: str = 'Arial', label_fontweight: str = 'bold', label_bbox=None,
legend_bbox: tuple = (0.7, 0.05), legend_ncol: int = 3, legend_fontsize: int = 12,
legend_fontweight: str = 'bold', curve_awarg=None,ylim=(-0.5,0.5),xlim=(-3,3)):
"""
Plot network graph.
Arguments:
G: networkx graph
G_type_dict: dict, node type dict
G_color_dict: dict, node color dict
pos_type: str, node position type, 'spring' or 'kamada_kawai'
pos_dim: int, node position dimension, 2 or 3
figsize: tuple, figure size
pos_scale: int, node position scale
pos_k: float, node position k
pos_alpha: float, node position alpha
node_size: int, node size
node_alpha: float, node alpha
node_linewidths: float, node linewidths
plot_node: list, plot node list
plot_node_num: int, plot node number
label_verticalalignment: str, label verticalalignment
label_fontsize: int, label fontsize
label_fontfamily: str, label fontfamily
label_fontweight: str, label fontweight
label_bbox: tuple, label bbox
legend_bbox: tuple, legend bbox
legend_ncol: int, legend ncol
legend_fontsize: int, legend fontsize
legend_fontweight: str, legend fontweight
curve_awarg: dict, arguments for curved graph
"""
from adjustText import adjust_text
fig, ax = plt.subplots(figsize=figsize)

# Determine node positions
if pos_type == 'spring':
pos = nx.spring_layout(G, scale=pos_scale, k=pos_k)
elif pos_type == 'kamada_kawai':
pos = nx.kamada_kawai_layout(G, dim=pos_dim, scale=pos_scale)
else:
pos = pos_type

degree_dict = dict(G.degree(G.nodes()))

# Update color and type dictionaries
G_color_dict = {node: G_color_dict[node] for node in G.nodes}
G_type_dict = {node: G_type_dict[node] for node in G.nodes}

# Draw nodes with different shapes if specified
if node_shape is not None:
for node_type, shape in node_shape.items():
node_list = [node for node in G.nodes if G_type_dict[node] == node_type]
nx.draw_networkx_nodes(
G,
pos,
nodelist=node_list,
node_size=100,
node_color=[G_color_dict[v] for v in node_list],
alpha=node_alpha,
linewidths=node_linewidths,
node_shape=shape # Set node shape
)
resultant_pattern_graph_edge_colours = [G_color_dict[edge[0]] for edge in G.edges()]
edge_widths = [0.2*G[u][v]['weight'] for u,v in G.edges()]
nx.draw_networkx_edges(G, pos,
edge_color=resultant_pattern_graph_edge_colours,
width=edge_widths, alpha=0.75, connectionstyle="arc3,rad=0.2")
else:
nx.draw_networkx_nodes(
G,
pos,
nodelist=list(G_color_dict.keys()),
node_size=[degree_dict[v] * node_size for v in G],
node_color=list(G_color_dict.values()),
alpha=node_alpha,
linewidths=node_linewidths,
)
resultant_pattern_graph_edge_colours = [G_color_dict[edge[0]] for edge in G.edges()]
edge_widths = [0.2*G[u][v]['weight'] for u,v in G.edges()]
nx.draw_networkx_edges(G, pos,
edge_color=resultant_pattern_graph_edge_colours,
width=edge_widths, alpha=0.75, connectionstyle="arc3,rad=0.2")

plt.xlim(xlim[0],xlim[1])
plt.ylim(ylim[0],ylim[1])

# Determine hub genes
if plot_node is not None:
hub_gene = plot_node
else:
hub_gene = [i[0] for i in sorted(degree_dict.items(), key=lambda x: x[1], reverse=True)[:plot_node_num]]

# Draw curved edges if specified
#if curve_awarg is not None:
# curved_graph(G, pos, color_dict=G_color_dict, **curve_awarg)
#else:
# curved_graph(G, pos, color_dict=G_color_dict)

pos1 = {i: pos[i] for i in hub_gene}

# Add labels to hub genes
texts = [ax.text(pos1[i][0], pos1[i][1], i, fontdict={'size': label_fontsize, 'weight': label_fontweight, 'color': 'black'})
for i in hub_gene if 'ENSG' not in i]
print(texts)

import adjustText
if adjustText.__version__<='0.8':
adjust_text(texts,only_move={'text': 'xy'},arrowprops=dict(arrowstyle='->', color='red'),)
else:
adjust_text(texts,only_move={"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"},
arrowprops=dict(arrowstyle='->', color='red'))

ax.axis("off")


# Prepare legend
t = pd.DataFrame(index=G_type_dict.keys())
t['gene_type_dict'] = G_type_dict
t['gene_color_dict'] = G_color_dict
type_color_dict = {i: t.loc[t['gene_type_dict'] == i, 'gene_color_dict'].values[0] for i in t['gene_type_dict'].value_counts().index}

patches = [mpatches.Patch(color=type_color_dict[i], label="{:s}".format(i)) for i in type_color_dict.keys()]
plt.tight_layout()
plt.margins(x=0.3)
plt.margins(y=0)
#plt.legend(handles=patches, bbox_to_anchor=legend_bbox, ncol=legend_ncol, fontsize=legend_fontsize)
#leg = plt.gca().get_legend()
#ltext = leg.get_texts()
#plt.setp(ltext, fontsize=legend_fontsize, fontweight=legend_fontweight)

return fig, ax


def plot_flowsig_network(flow_network,
gem_plot,
figsize=(8,4),
curve_awarg={'eps':2},
node_shape={'GEM':'^','Sender':'o','Receptor':'o'},
**kwargs):

gem_li=[i for i in flow_network.nodes if 'GEM' in i]
receptor_li=[i for i,j in flow_network.edges if ('GEM' in j)and('GEM' not in i)]
sender_li=[j for i,j in flow_network.edges if ('GEM' in i)and('GEM' not in j)]
#gene_li=[i for i in flow_network.nodes if 'GEM' not in i]
G_type_dict=dict(zip(gem_li+receptor_li+sender_li,
['GEM' for i in range(len(gem_li))]+\
['Receptor' for i in range(len(receptor_li))]+\
['Sender' for i in range(len(sender_li))]))

gem_li=[i for i in flow_network.nodes if 'GEM' in i]
receptor_li=[i for i,j in flow_network.edges if ('GEM' in j)and('GEM' not in i)]
sender_li=[j for i,j in flow_network.edges if ('GEM' in i)and('GEM' not in j)]
#gene_li=[i for i in flow_network.nodes if 'GEM' not in i]
if len(gem_li)<28:
from ._palette import sc_color
G_color_dict=dict(zip(gem_li+receptor_li+sender_li,
[sc_color[i] for i in range(len(gem_li))]+\
['#c2c2c2' for i in range(len(receptor_li))]+\
['#a51616' for i in range(len(sender_li))]))
else:
sc_color=sc.pl.palettes.default_102
G_color_dict=dict(zip(gem_li+receptor_li+sender_li,
[sc_color[i] for i in range(len(gem_li))]+\
['#c2c2c2' for i in range(len(receptor_li))]+\
['#a51616' for i in range(len(sender_li))]))

#G = nx.Graph([(0, 1), (1, 2), (1, 3), (3, 4)])
import networkx as nx
layers = {"layer1": list(set(sender_li)), "layer3": list(set(receptor_li)),"layer2": list(set(gem_li)),}
pos = nx.multipartite_layout(flow_network, subset_key=layers,align='horizontal',scale=2.0)

#fig, ax = plt.subplots(figsize=(8,8))
sub_G=flow_network.subgraph(sender_li+receptor_li+gem_plot).copy()
# 获取所有节点和边
all_nodes = set(sub_G.nodes())
all_edges = sub_G.edges()
# 找到没有连接边的节点
nodes_with_edges = set()
for edge in all_edges:
nodes_with_edges.update(edge)

nodes_without_edges = all_nodes - nodes_with_edges

# 从图中删除这些节点
sub_G.remove_nodes_from(nodes_without_edges)

from matplotlib.collections import LineCollection
from matplotlib.patches import FancyArrowPatch
import networkx as nx
fig,ax=plot_curve_network(sub_G,
G_type_dict=G_type_dict,
G_color_dict=G_color_dict,
pos_type=pos,
figsize=figsize,
curve_awarg=curve_awarg,
node_shape=node_shape,
**kwargs)

plt.box(False)
return fig,ax

2 changes: 1 addition & 1 deletion omicverse/pl/_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def cellstackarea(adata,celltype_clusters:str,groupby:str,
print('{}_colors'.format(celltype_clusters))
type_color_all=dict(zip(adata.obs[celltype_clusters].cat.categories,adata.uns['{}_colors'.format(celltype_clusters)]))
else:
if len(adata.obs[cluster_key].cat.categories)>28:
if len(adata.obs[celltype_clusters].cat.categories)>28:
type_color_all=dict(zip(adata.obs[celltype_clusters].cat.categories,sc.pl.palettes.default_102))
else:
type_color_all=dict(zip(adata.obs[celltype_clusters].cat.categories,sc.pl.palettes.zeileis_28))
Expand Down
Loading

0 comments on commit a2047cd

Please sign in to comment.