Skip to content

Commit

Permalink
Merge pull request #73 from EliHei2/fix_cc
Browse files Browse the repository at this point in the history
fixed prediction to use 3 dims + updated the notebook
  • Loading branch information
EliHei2 authored Dec 28, 2024
2 parents b1c4b96 + 8424e76 commit e39ddee
Show file tree
Hide file tree
Showing 7 changed files with 731 additions and 271 deletions.
301 changes: 87 additions & 214 deletions docs/notebooks/segger_tutorial.ipynb

Large diffs are not rendered by default.

36 changes: 36 additions & 0 deletions scripts/create_data_fast_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from segger.data.parquet.sample import STSampleParquet
from path import Path
from segger.data.utils import calculate_gene_celltype_abundance_embedding
import scanpy as sc

# ---------------------------------------------------------------------------
# Locations (all relative to the current working directory).
# ---------------------------------------------------------------------------
scrnaseq_file = Path("data_tidy/benchmarks/xe_rep1_bc/scRNAseq.h5ad")
xenium_data_dir = Path("data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/")
segger_data_dir = Path("data_tidy/pyg_datasets/bc_fast_data_emb_major")

# Column of the scRNA-seq reference handed to the embedding helper.
celltype_column = "celltype_major"

# Build the gene/cell-type abundance embedding from the scRNA-seq reference.
gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(
    sc.read(scrnaseq_file), celltype_column
)

# Describe the raw Xenium sample that will be converted.
sample = STSampleParquet(
    base_dir=xenium_data_dir,
    n_workers=4,
    sample_type="xenium",
    # The embedding computed above is passed as `weights`. Per the original
    # note this argument is only meant to be supplied when gene-celltype
    # embeddings are available (presumably optional -- confirm against
    # STSampleParquet before removing it).
    weights=gene_celltype_abundance_embedding,
)

# Settings forwarded verbatim to `sample.save`; kept in one mapping so the
# whole dataset configuration can be read at a glance.
save_settings = {
    "data_dir": segger_data_dir,
    "k_bd": 3,
    "dist_bd": 15.0,
    "k_tx": 20,
    "dist_tx": 3,
    "tile_width": 220,
    "tile_height": 220,
    "neg_sampling_ratio": 5.0,
    "frac": 1.0,
    "val_prob": 0.1,
    "test_prob": 0.1,
}

# Write the processed dataset to `segger_data_dir`.
sample.save(**save_settings)
53 changes: 11 additions & 42 deletions scripts/predict_model_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
import dask.dataframe as dd


segger_data_dir = Path("./data_tidy/pyg_datasets/bc_embedding_1001")
models_dir = Path("./models/bc_embedding_1001_small")

seg_tag = "bc_fast_data_emb_major"
model_version = 1

segger_data_dir = Path('data_tidy/pyg_datasets') / seg_tag
models_dir = Path("./models") / seg_tag
benchmarks_dir = Path("/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc")
transcripts_file = "data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet"
# Initialize the Lightning data module
Expand All @@ -30,63 +34,28 @@
dm.setup()


model_version = 0

# Load in latest checkpoint
model_path = models_dir / "lightning_logs" / f"version_{model_version}"
model = load_model(model_path / "checkpoints")

receptive_field = {"k_bd": 4, "dist_bd": 20, "k_tx": 5, "dist_tx": 3}
receptive_field = {"k_bd": 4, "dist_bd": 15, "k_tx": 5, "dist_tx": 3}

segment(
model,
dm,
save_dir=benchmarks_dir,
seg_tag="parquet_test_big",
seg_tag=seg_tag,
transcript_file=transcripts_file,
# file_format='anndata',
receptive_field=receptive_field,
min_transcripts=5,
score_cut=0.5,
score_cut=0.1,
# max_transcripts=1500,
cell_id_col="segger_cell_id",
use_cc=True,
knn_method="cuda",
use_cc=False,
knn_method="kd_tree",
verbose=True,
gpu_ids=["0"],
# client=client
)


# if __name__ == "__main__":
# cluster = LocalCUDACluster(
# # CUDA_VISIBLE_DEVICES="0",
# n_workers=1,
# dashboard_address=":8080",
# memory_limit='30GB', # Adjust based on system memory
# lifetime="2 hours", # Increase worker lifetime
# lifetime_stagger="75 minutes",
# local_directory='.', # Stagger worker restarts
# lifetime_restart=True # Automatically restart workers
# )
# client = Client(cluster)

# segment(
# model,
# dm,
# save_dir=benchmarks_dir,
# seg_tag='segger_embedding_0926_mega_0.5_20',
# transcript_file=transcripts_file,
# file_format='anndata',
# receptive_field = receptive_field,
# min_transcripts=5,
# score_cut=0.5,
# # max_transcripts=1500,
# cell_id_col='segger_cell_id',
# use_cc=False,
# knn_method='cuda',
# # client=client
# )

# client.close()
# cluster.close()
Loading

0 comments on commit e39ddee

Please sign in to comment.