Skip to content

Commit

Permalink
Popv fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
axdanbol committed Mar 26, 2024
1 parent 9d1fd33 commit a005bdb
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 43 deletions.
2 changes: 1 addition & 1 deletion containers/popv/context/download-models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ MODELS_ID=${1:?"A zenodo models id must be provided to download!"}
MODELS_DIR=${2:-"./popv/models"}

mkdir -p $MODELS_DIR
zenodo_get $MODELS_ID -o $MODELS_DIR
zenodo_get $MODELS_ID -o $MODELS_DIR --continue-on-error --retry 2 --pause 30

for ARCHIVE in $MODELS_DIR/*.tar.gz; do
MODEL=$(basename -s .tar.gz $ARCHIVE)
Expand Down
2 changes: 1 addition & 1 deletion containers/popv/context/download-reference-data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ REFERENCE_DATA_ID=${1:?"A zenodo reference data id must be provided to download!
REFERENCE_DATA_DIR=${2:-"./popv/reference-data"}

mkdir -p $REFERENCE_DATA_DIR
zenodo_get $REFERENCE_DATA_ID -o $REFERENCE_DATA_DIR
zenodo_get $REFERENCE_DATA_ID -o $REFERENCE_DATA_DIR --continue-on-error --retry 2 --pause 30
108 changes: 93 additions & 15 deletions containers/popv/context/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,89 @@
from pathlib import Path

import anndata
import numpy
import popv
import celltypist
import h5py
import numpy as np
import pandas as pd
import scanpy
import scipy.sparse as sp_sparse
import scvi.data.fields._layer_field as scvi_layer_field
import torch

import popv
from src.algorithm import Algorithm, RunResult, add_common_arguments
from src.util.layers import set_data_layer

# From https://github.com/scverse/scvi-tools/blob/1.1.2/scvi/data/_utils.py#L15
try:
# anndata >= 0.10
from anndata.experimental import CSCDataset, CSRDataset

SparseDataset = (CSRDataset, CSCDataset)
except ImportError:
from anndata._core.sparse_dataset import SparseDataset


# From https://github.com/scverse/scvi-tools/blob/1.1.2/scvi/data/_utils.py#L248
# But with jax operations replaced by regular numpy calls
def _check_nonnegative_integers(
data: t.Union[pd.DataFrame, np.ndarray, sp_sparse.spmatrix, h5py.Dataset],
n_to_check: int = 20,
):
"""Approximately checks values of data to ensure it is count data."""
# for backed anndata
if isinstance(data, h5py.Dataset) or isinstance(data, SparseDataset):
data = data[:100]

if isinstance(data, np.ndarray):
data = data
elif issubclass(type(data), sp_sparse.spmatrix):
data = data.data
elif isinstance(data, pd.DataFrame):
data = data.to_numpy()
else:
raise TypeError("data type not understood")

ret = True
if len(data) != 0:
inds = np.random.choice(len(data), size=(n_to_check,))
# Start of replacements
data = data.flat[inds]
negative = np.any(data < 0)
non_integer = np.any(data % 1 != 0)
# End of replacements
ret = not (negative or non_integer)
return ret


def _fix_jax_segfault():
"""Fixes a segfault that can happen inside docker containers
when running both knn_on_scvi and scanvi.
The error is caused by a race condition or data corruption in jax
when the algorithms load their respective model files.
I suspect there might be a slight version mismatch or similar when
creating the docker container but for now I just monkey patch the offending calls.
"""
scvi_layer_field._check_nonnegative_integers = _check_nonnegative_integers


def _fix_celltypist_forced_models_download(model_dir: Path):
"""Prevent celltypist from redownloading all models.
Celltypist's `Model.load` function always attempts to download
all models if it cannot detect at least one *.pkl (pickle serialized)
file in it's default models directory even when provided with
a direct path to a model file.
Monkey patching celltypist's models directory path to a directory
with at least on *.pkl file will trick it into not downloading the models.
Args:
model_dir (Path): Directory with at least one *.pkl file
"""
celltypist.models.models_path = model_dir


class PopvOrganMetadata(t.TypedDict):
model: str
Expand Down Expand Up @@ -45,20 +120,21 @@ def do_run(
options: PopvOptions,
) -> RunResult:
"""Annotate data using popv."""
_fix_jax_segfault()

data = scanpy.read_h5ad(matrix)
data = self.prepare_query(data, organ, metadata["model"], options)
popv.annotation.annotate_data(
data,
# TODO: onclass has been removed due to error in fast mode
# seen_result_key is not added to the result in fast mode but still expected during compute_consensus
# https://github.com/YosefLab/PopV/blob/main/popv/annotation.py#L64
# https://github.com/YosefLab/PopV/blob/main/popv/algorithms/_onclass.py#L199
# Also excludes celltypist since web requests are not available inside the docker container
methods=[
"knn_on_scvi",
"scanvi",
"svm",
"rf",
# "onclass",
"celltypist",
],
)

Expand All @@ -81,23 +157,25 @@ def prepare_query(
reference_data_path = self.find_reference_data(
options["reference_data_dir"], organ, model
)
model_path = self.find_model_dir(options["models_dir"], organ, model)
reference_data = scanpy.read_h5ad(reference_data_path)
n_samples_per_label = self.get_n_samples_per_label(reference_data, options)
data = self.normalize_var_names(data, options)
data = set_data_layer(data, options["query_layers_key"])

if options["query_layers_key"] in ('X', 'raw'):
if options["query_layers_key"] in ("X", "raw"):
options["query_layers_key"] = None
data.X = numpy.rint(data.X)
data.X = np.rint(data.X)

model_dir = self.find_model_dir(options["models_dir"], organ, model)
_fix_celltypist_forced_models_download(model_dir)

data = self.add_model_genes(data, model_path, options["query_layers_key"])
data = self.add_model_genes(data, model_dir, options["query_layers_key"])
data.var_names_make_unique()

query = popv.preprocessing.Process_Query(
data,
reference_data,
save_path_trained_models=str(model_path),
save_path_trained_models=f"{model_dir}/",
prediction_mode=options["prediction_mode"],
query_labels_key=options["query_labels_key"],
query_batch_key=options["query_batch_key"],
Expand All @@ -109,7 +187,7 @@ def prepare_query(
cl_obo_folder=f"{options['cell_ontology_dir']}/",
compute_embedding=True,
hvg=None,
use_gpu=False, # Using gpu with docker requires additional setup
accelerator="cpu", # Using gpu with docker/apptainer requires additional setup
)
return query.adata

Expand All @@ -128,8 +206,8 @@ def get_n_samples_per_label(
ref_labels_key = options["ref_labels_key"]
n_samples_per_label = options["samples_per_label"]
if ref_labels_key in reference_data.obs.columns:
n = numpy.min(reference_data.obs.groupby(ref_labels_key).size())
n_samples_per_label = numpy.max((n_samples_per_label, t.cast(int, n)))
n = np.min(reference_data.obs.groupby(ref_labels_key).size())
n_samples_per_label = np.max((n_samples_per_label, t.cast(int, n)))
return n_samples_per_label

def find_reference_data(self, dir: Path, organ: str, model: str) -> Path:
Expand Down Expand Up @@ -274,8 +352,8 @@ def add_model_genes(
Path.joinpath(model_path, "scvi/model.pt"), map_location="cpu"
)["var_names"]
n_obs_data = data.X.shape[0]
new_genes = set(numpy.setdiff1d(model_genes, data.var_names))
zeroes = numpy.zeros((n_obs_data, len(new_genes)))
new_genes = set(np.setdiff1d(model_genes, data.var_names))
zeroes = np.zeros((n_obs_data, len(new_genes)))
layers = {query_layers_key: zeroes} if query_layers_key else None
new_data = scanpy.AnnData(X=zeroes, var=new_genes, layers=layers)
new_data.obs_names = data.obs_names
Expand Down
51 changes: 26 additions & 25 deletions containers/popv/context/requirements-freeze.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,55 +3,55 @@ aiohttp==3.9.3
aiosignal==1.3.1
anndata==0.10.6
annoy==1.17.3
array_api_compat==1.4.1
array_api_compat==1.5.1
astunparse==1.6.3
attrs==23.2.0
bbknn==1.6.0
beautifulsoup4==4.12.3
celltypist==1.6.2
certifi==2024.2.2
charset-normalizer==3.3.2
chex==0.1.85
chex==0.1.86
click==8.1.7
contextlib2==21.6.0
contourpy==1.2.0
cycler==0.12.1
Cython==3.0.9
dm-tree==0.1.8
docrep==0.3.2
et-xmlfile==1.1.0
etils==1.7.0
etils==1.8.0
fbpca==1.0
filelock==3.13.1
flatbuffers==24.3.7
flax==0.8.1
fonttools==4.49.0
filelock==3.13.3
flatbuffers==24.3.25
flax==0.8.2
fonttools==4.50.0
frozenlist==1.4.1
fsspec==2024.2.0
fsspec==2024.3.1
gast==0.5.4
gdown==5.1.0
geosketch==1.2
google-pasta==0.2.0
grpcio==1.62.1
h5py==3.10.0
harmony-pytorch==0.1.8
huggingface-hub==0.21.4
huggingface-hub==0.22.1
idna==3.6
igraph==0.11.4
importlib_resources==6.2.0
importlib_resources==6.4.0
intervaltree==3.1.0
jax==0.4.25
jaxlib==0.4.25
Jinja2==3.1.3
joblib==1.3.2
keras==3.0.5
keras==3.1.1
kiwisolver==1.4.5
legacy-api-wrap==1.4
leidenalg==0.10.2
libclang==16.0.6
libclang==18.1.1
lightning==2.1.4
lightning-utilities==0.10.1
lightning-utilities==0.11.1
llvmlite==0.42.0
Markdown==3.5.2
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.3
Expand All @@ -67,7 +67,7 @@ namex==0.0.7
natsort==8.4.0
nest-asyncio==1.6.0
networkx==3.2.1
numba==0.59.0
numba==0.59.1
numpy==1.26.4
numpyro==0.14.0
nvidia-cublas-cu12==12.1.3.1
Expand All @@ -87,7 +87,8 @@ OnClass==1.3
openpyxl==3.1.2
opt-einsum==3.3.0
optax==0.2.1
orbax-checkpoint==0.5.5
optree==0.11.0
orbax-checkpoint==0.5.7
packaging==24.0
pandas==1.5.3
patsy==0.5.6
Expand All @@ -110,13 +111,13 @@ requests==2.31.0
rich==13.7.1
safetensors==0.4.2
scanorama==1.7.4
scanpy==1.9.8
scanpy==1.10.0
scikit-learn==1.1.3
scikit-misc==0.3.1
scipy==1.12.0
scvi-tools==1.1.2
seaborn==0.13.2
sentence-transformers==2.5.1
sentence-transformers==2.6.1
session_info==1.0.0
six==1.16.0
sortedcontainers==2.4.0
Expand All @@ -128,16 +129,16 @@ tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
tensorflow-io-gcs-filesystem==0.36.0
tensorstore==0.1.54
tensorstore==0.1.56
termcolor==2.4.0
texttable==1.7.0
threadpoolctl==3.3.0
threadpoolctl==3.4.0
tokenizers==0.15.2
toolz==0.12.1
torch==2.2.1
torchmetrics==1.3.1
torchmetrics==1.3.2
tqdm==4.66.2
transformers==4.38.2
transformers==4.39.1
triton==2.2.0
typing_extensions==4.10.0
umap-learn==0.5.5
Expand All @@ -146,5 +147,5 @@ Werkzeug==3.0.1
wget==3.2
wrapt==1.16.0
yarl==1.9.4
zenodo-get==1.4.0
zipp==3.17.0
zenodo-get==1.5.1
zipp==3.18.1
2 changes: 1 addition & 1 deletion containers/popv/context/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
popv==0.4.*
zenodo_get==1.4.0
zenodo_get==1.5.*

0 comments on commit a005bdb

Please sign in to comment.