-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path2-extract-expression-features.py
91 lines (82 loc) · 2.56 KB
/
2-extract-expression-features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import os
import shutil
import subprocess
import sys
import tempfile
import pandas as pd
from anndata import read_h5ad
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for UCE expression-feature extraction.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        The parsed namespace, augmented with ``model_loc`` (checkpoint path)
        and ``nlayers`` (layer count), both derived from ``--model``.
    """
    # Checkpoint path and layer count for each supported UCE variant. The
    # paths are relative; presumably they resolve against the UCE checkout
    # that the evaluation subprocess runs in — confirm against the caller's cwd.
    model_configs = {
        "uce_4": ("./model_files/4layer_model.torch", 4),
        "uce_33": ("./model_files/33l_8ep_1024t_1280.torch", 33),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_h5ad",
        required=True,
        help="Path to spatial AnnData.",
    )
    parser.add_argument(
        "--output_h5ad",
        required=True,
        help="Path to save extracted expression features.",
    )
    # Not required: argparse ignores `default` on a required option, so the
    # former `required=True` made the "uce_4" default unreachable.
    parser.add_argument(
        "--model",
        choices=list(model_configs),
        default="uce_4",
        help="Foundation model to use to extract expression features.",
    )
    parser.add_argument(
        "--species",
        required=True,
        help="Species defined by the foundation model. Most commonly `human` or `mouse`.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for inference.",
    )
    args = parser.parse_args(argv)

    # `choices` guarantees the key exists.
    args.model_loc, args.nlayers = model_configs[args.model]
    return args
def main() -> None:
    """Embed the cells/spots of ``--input_h5ad`` with UCE and save the result.

    Steps:
    - Write a stripped copy of the input AnnData (empty ``obs`` columns, no
      ``uns``/``obsm``) to a temporary directory.
    - Run UCE's ``eval_single_anndata.py`` on that copy in a subprocess.
    - Read UCE's output, rename ``obsm["X_uce"]`` to ``obsm["X_expr"]``, and
      write it to ``--output_h5ad``.

    Raises:
        subprocess.CalledProcessError: if the UCE subprocess exits non-zero.
    """
    args = parse_args()
    # Absolute path so the UCE checkout is found no matter how the script's
    # path was given on the command line.
    repo_root = os.path.dirname(os.path.abspath(__file__))
    uce_dir = os.path.join(repo_root, "models", "UCE")

    with tempfile.TemporaryDirectory() as tmp_root:
        tmp_dir = os.path.join(tmp_root, "")  # trailing slash necessary for UCE pipeline

        # Keep only the obs index; drop obs columns, uns and obsm before
        # handing the data to UCE.
        adata = read_h5ad(args.input_h5ad)
        adata.obs = pd.DataFrame(index=adata.obs.index)
        del adata.uns
        del adata.obsm
        tmp_h5ad = os.path.join(tmp_dir, "tmp.h5ad")
        adata.write_h5ad(tmp_h5ad)
        del adata  # not needed while the subprocess runs

        # check=True: stop with UCE's exit status instead of failing later
        # with a confusing "file not found" on the missing output file.
        subprocess.run(
            args=[
                sys.executable,  # python
                "eval_single_anndata.py",
                "--adata_path",
                tmp_h5ad,
                "--dir",
                tmp_dir,
                "--species",
                args.species,
                "--model_loc",
                args.model_loc,
                "--nlayers",
                str(args.nlayers),
                "--batch_size",
                str(args.batch_size),
            ],
            cwd=uce_dir,
            check=True,
        )

        # Presumably UCE names its output after the input file stem
        # ("tmp" -> "tmp_uce_adata.h5ad") — TODO confirm against UCE.
        uce_h5ad = os.path.join(tmp_dir, "tmp_uce_adata.h5ad")
        uce_adata = read_h5ad(uce_h5ad)
        # Store the embeddings under this project's key name.
        uce_adata.obsm["X_expr"] = uce_adata.obsm["X_uce"]
        del uce_adata.obsm["X_uce"]
        uce_adata.write_h5ad(args.output_h5ad)
        print(f"Saved expression features from {uce_h5ad} to {args.output_h5ad}")


if __name__ == "__main__":
    main()