Skip to content

Commit

Permalink
fix nonstring leiden
Browse files Browse the repository at this point in the history
  • Loading branch information
jsxlei committed Aug 22, 2024
1 parent ed675a1 commit 72a19a7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 26 deletions.
5 changes: 3 additions & 2 deletions scalex/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def SCALEX(
"""
import torch
import numpy as np
import pandas as pd
import os
import scanpy as sc

Expand Down Expand Up @@ -225,8 +226,8 @@ def SCALEX(
if 'leiden' in adata.obs:
del adata.obs['leiden']

if outdir is not None:
adata.write(os.path.join(outdir, 'adata.h5ad'), compression='gzip')
# if outdir is not None:
# adata.write(os.path.join(outdir, 'adata.h5ad'), compression='gzip')

if not ignore_umap: #and adata.shape[0]<1e6:
log.info('Plot umap')
Expand Down
59 changes: 35 additions & 24 deletions scalex/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def embedding(
sep='_',
basis='X_umap',
size=10,
n_cols=4,
show=True,
):
"""
Expand Down Expand Up @@ -57,37 +58,47 @@ def embedding(
"""

if groups is None:
groups = adata.obs[groupby].cat.categories
_groups = adata.obs[groupby].cat.categories
else:
_groups = groups

# Create subplots
num_plots = len(groups)
fig, axes = plt.subplots(num_plots, 1, figsize=(5, 5 * num_plots), squeeze=False)
n_plots = len(_groups)
n_rows = (n_plots + n_cols - 1) // n_cols # Calculate number of rows

for ax, b in zip(axes.flatten(), groups):
adata.obs['tmp'] = adata.obs[color].astype(str)
adata.obs.loc[adata.obs[groupby]!=b, 'tmp'] = ''
if cond2 is not None:
adata.obs.loc[adata.obs[cond2]!=v2, 'tmp'] = ''
groups = list(adata[(adata.obs[groupby]==b) &
(adata.obs[cond2]==v2)].obs[color].astype('category').cat.categories.values)
size = min(size, 120000/len(adata[(adata.obs[groupby]==b) & (adata.obs[cond2]==v2)]))
else:
groups = list(adata[adata.obs[groupby]==b].obs[color].astype('category').cat.categories.values)
size = min(size, 120000/len(adata[adata.obs[groupby]==b]))
adata.obs['tmp'] = adata.obs['tmp'].astype('category')
if color_map is not None:
palette = [color_map[i] if i in color_map else 'gray' for i in adata.obs['tmp'].cat.categories]
else:
palette = None
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*n_rows))

for j, ax in enumerate(axes.flatten()):
if j < n_plots:
b = _groups[j]
adata.obs['tmp'] = adata.obs[color].astype(str)
adata.obs.loc[adata.obs[groupby]!=b, 'tmp'] = ''
if cond2 is not None:
adata.obs.loc[adata.obs[cond2]!=v2, 'tmp'] = ''
groups = list(adata[(adata.obs[groupby]==b) &
(adata.obs[cond2]==v2)].obs[color].astype('category').cat.categories.values)
size = min(size, 120000/len(adata[(adata.obs[groupby]==b) & (adata.obs[cond2]==v2)]))
else:
groups = list(adata[adata.obs[groupby]==b].obs[color].astype('category').cat.categories.values)
size = min(size, 120000/len(adata[adata.obs[groupby]==b]))
adata.obs['tmp'] = adata.obs['tmp'].astype('category')
if color_map is not None:
palette = [color_map[i] if i in color_map else 'gray' for i in adata.obs['tmp'].cat.categories]
else:
palette = None

title = b if cond2 is None else v2+sep+b
title = b if cond2 is None else v2+sep+b

ax = sc.pl.embedding(adata, color='tmp', basis=basis, groups=groups, ax=ax, title=title, palette=palette, size=size,
legend_loc=legend_loc, legend_fontsize=legend_fontsize, legend_fontweight=legend_fontweight, wspace=0.25, show=False)

del adata.obs['tmp']
del adata.uns['tmp_colors']
else:
fig.delaxes(ax)

ax = sc.pl.embedding(adata, color='tmp', basis=basis, groups=groups, ax=ax, title=title, palette=palette, size=size,
legend_loc=legend_loc, legend_fontsize=legend_fontsize, legend_fontweight=legend_fontweight, wspace=0.25, show=False)


del adata.obs['tmp']
del adata.uns['tmp_colors']
if save:
plt.savefig(save, bbox_inches='tight')

Expand Down

0 comments on commit 72a19a7

Please sign in to comment.