Skip to content

Commit

Permalink
fix plot_meta
Browse files Browse the repository at this point in the history
  • Loading branch information
jsxlei committed Aug 26, 2024
1 parent 72a19a7 commit 28f9179
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions scalex/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def embedding(
legend_fontweight='bold',
sep='_',
basis='X_umap',
size=10,
size=20,
n_cols=4,
show=True,
):
Expand Down Expand Up @@ -104,7 +104,7 @@ def embedding(

def plot_meta(
adata,
use_rep=None,
use_rep='latent',
color='celltype',
batch='batch',
colors=None,
Expand Down Expand Up @@ -146,15 +146,12 @@ def plot_meta(
"""
meta = []
name = []
color = []
if colors is None:
colors = ['#FFFF00', '#1CE6FF', '#FF34FF', '#FF4A46', '#008941', '#006FA6', '#A30059', '#FFDBE5', '#7A4900', '#0000A6',
'#63FFAC', '#B79762', '#004D43', '#8FB0FF', '#997D87', '#5A0007', '#809693', '#6A3A4C', '#1B4400', '#4FC601',
'#3B5DFF', '#4A3B53', '#FF2F80', '#61615A', '#BA0900', '#6B7900', '#00C2A0', '#FFAA92', '#FF90C9', '#B903AA',
'#D16100', '#DDEFFF', '#000035', '#7B4F4B', '#A1C299', '#300018', '#0AA6D8', '#013349', '#00846F', '#372101',
'#FFB500', '#C2FFED', '#A079BF', '#CC0744', '#C0B9B2', '#C2FF99', '#001E09']
color_list = []

adata.obs[color] = adata.obs[color].astype('category')
batches = np.unique(adata.obs[batch])
if colors is None:
colors = sns.color_palette("tab10", len(np.unique(adata.obs[color])))
for i,b in enumerate(batches):
for cat in adata.obs[color].cat.categories:
index = np.where((adata.obs[color]==cat) & (adata.obs[batch]==b))[0]
Expand All @@ -166,7 +163,7 @@ def plot_meta(
else:
meta.append(adata.X[index].mean(0))
name.append(cat)
color.append(colors[i])
color_list.append(colors[i])


meta = np.stack(meta)
Expand All @@ -177,8 +174,8 @@ def plot_meta(
mask[np.triu_indices_from(mask, k=1)] = True
grid = sns.heatmap(corr, mask=mask, xticklabels=name, yticklabels=name, annot=annot, # name -> []
cmap=cmap, square=True, cbar=True, vmin=vmin, vmax=vmax)
[ tick.set_color(c) for tick,c in zip(grid.get_xticklabels(),color) ]
[ tick.set_color(c) for tick,c in zip(grid.get_yticklabels(),color) ]
[ tick.set_color(c) for tick,c in zip(grid.get_xticklabels(),color_list) ]
[ tick.set_color(c) for tick,c in zip(grid.get_yticklabels(),color_list) ]
plt.xticks(rotation=45, horizontalalignment='right', fontsize=fontsize)
plt.yticks(fontsize=fontsize)

Expand Down

0 comments on commit 28f9179

Please sign in to comment.