diff --git a/docs/notebooks/segger_tutorial.ipynb b/docs/notebooks/segger_tutorial.ipynb index ee4cebc..6a3c11b 100644 --- a/docs/notebooks/segger_tutorial.ipynb +++ b/docs/notebooks/segger_tutorial.ipynb @@ -16,6 +16,10 @@ "source": [ "# **Introduction to Segger**\n", "\n", + "\n", + "**Important note (Dec 2024):** As segger is currently undergoing constant development, we highly recommend installing it directly from GitHub.\n", + "\n", + "\n", "Segger is a cutting-edge cell segmentation model specifically designed for **single-molecule resolved spatial omics** datasets. It addresses the challenge of accurately segmenting individual cells in complex imaging datasets, leveraging a unique approach based on graph neural networks (GNNs).\n", "\n", "The core idea behind Segger is to model both **nuclei** and **transcripts** as graph nodes, with edges connecting them based on their spatial proximity. This allows the model to learn from the co-occurrence of nucleic and cytoplasmic molecules, resulting in more refined and accurate cell boundaries. By using spatial information and GNNs, Segger achieves state-of-the-art performance in segmenting single cells in datasets such as 10X Xenium and MERSCOPE, outperforming traditional methods like Baysor and Cellpose.\n", @@ -87,11 +91,10 @@ }, "outputs": [], "source": [ - "from segger.data.io import XeniumSample\n", + "from segger.data.parquet.sample import STSampleParquet\n", "from segger.training.segger_data_module import SeggerDataModule\n", "from segger.training.train import LitSegger\n", - "from segger.prediction.predict import predict, load_model\n", - "from segger.data.utils import calculate_gene_celltype_abundance_embedding\n", + "from segger.prediction.predict_parquet import segment, load_model\n", "from lightning.pytorch.loggers import CSVLogger\n", "from pytorch_lightning import Trainer\n", "from pathlib import Path\n", @@ -112,29 +115,16 @@ "\n", "In this step, we generate the dataset required for Segger's cell segmentation tasks.\n", "\n", - "Segger relies on spatial transcriptomics data, combining **nuclei** and **transcripts** from single-cell resolved imaging datasets. These nuclei and transcript nodes are represented in a graph, and the spatial proximity of transcripts to nuclei is used to establish edges between them.\n", + "Segger relies on spatial transcriptomics data, combining staining **boundaries** (e.g., nuclei or membrane stainings) and **transcripts** from single-cell resolved imaging datasets. These nuclei and transcript nodes are represented in a graph, and the spatial proximity of transcripts to nuclei is used to establish edges between them.\n", "\n", - "To use Segger with a Xenium dataset, you need the **`transcripts.csv.gz`** and **`nucleus_boundaries.csv.gz`** files. The **transcripts** file contains spatial coordinates and information for each transcript, while the **nucleus boundaries** file defines the polygon boundaries of the nuclei. These files enable Segger to map transcripts to their respective nuclei and perform cell segmentation based on spatial relationships. Segger can also be extended to other platforms by modifying the column names or formats in the input files to match its expected structure, making it adaptable for various spatial transcriptomics technologies." + "To use Segger with a Xenium dataset, you need the **`transcripts.parquet`** and **`nucleus_boundaries.parquet`** (or **`cell_boundaries.parquet`**, in case the Xenium sample comes with the segmentation kit) files. 
The **transcripts** file contains spatial coordinates and information for each transcript, while the **boundaries** file defines the polygon boundaries of the nuclei or cells. These files enable segger to map transcripts to their respective nuclei and perform cell segmentation based on spatial relationships. Segger can also be extended to other platforms by modifying the column names or formats in the input files to match its expected structure, making it adaptable for various spatial transcriptomics technologies. See [this](https://github.com/EliHei2/segger_dev/tree/main/src/segger/data/parquet/_settings) for Xenium settings." ] }, { "cell_type": "markdown", - "id": "4e5df7f3-7f36-45b4-b7da-b301513efbce", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-11T22:56:41.967336Z", - "iopub.status.busy": "2024-09-11T22:56:41.966988Z" - }, - "id": "4e5df7f3-7f36-45b4-b7da-b301513efbce" - }, - "source": [ - "To create the dataset, you need to specify the path to the **transcripts** file and the **nuclei boundaries** file. These are typically downloaded from a spatial transcriptomics dataset like the [Xenium Human Pancreatic Dataset](https://www.10xgenomics.com/products/xenium-human-pancreatic-dataset-explorer).\n", - "\n", - "- **`--transcripts_path`**: Path to the transcripts file, which contains single-cell transcriptomic data.\n", - "- **`--boundaries_path`**: Path to the boundaries file, most often representing the nuclei boundaries in the imaging dataset.\n", - "\n", - "If single cell RNA sequencing results are available, you can incorporate them as features for segger:" - ] + "id": "f488a0b7", + "metadata": {}, + "source": [] }, { "cell_type": "code", @@ -142,102 +132,7 @@ "id": "598b4b16", "metadata": {}, "outputs": [], - "source": [ - "# scrnaseq_file = Path('my_scRNAseq_file.h5ad')\n", - "# celltype_column = 'celltype_column'\n", - "# gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(\n", - "# sc.read(scrnaseq_file),\n", - "# celltype_column\n", - "# )" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "edd35db3-56e4-4a3e-9309-f83133274d47", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-12T00:49:03.132088Z", - "iopub.status.busy": "2024-09-12T00:49:03.131930Z", - "iopub.status.idle": "2024-09-12T00:49:05.472827Z", - "shell.execute_reply": "2024-09-12T00:49:05.472394Z", - "shell.execute_reply.started": "2024-09-12T00:49:03.132072Z" - }, - "id": "edd35db3-56e4-4a3e-9309-f83133274d47" - }, - "outputs": [], - "source": [ - "# Paths to Xenium sample data and where to store Segger data\n", - "xenium_data_dir = Path('data_xenium')\n", - "segger_data_dir = Path('data_segger')\n", - "\n", - "# Setup Xenium sample to create dataset\n", - "xs = XeniumSample(\n", - " verbose=False,\n", - " # embedding_df=gene_celltype_abundance_embedding # uncomment if gene-celltype embeddings are available\n", - ")\n", - "xs.set_file_paths(\n", - " transcripts_path=xenium_data_dir / 'transcripts.parquet',\n", - " boundaries_path=xenium_data_dir / 'nucleus_boundaries.parquet',\n", - ")\n", - "xs.set_metadata()" - ] - }, - { - "cell_type": "markdown", - "id": "33bd04f6-c4e3-42f8-81b2-c1e483d9faaf", - "metadata": { - "id": "33bd04f6-c4e3-42f8-81b2-c1e483d9faaf" - }, - "source": [ - "The following parameters are used to build a tiled Segger dataset:\n", - "\n", - "- **`--processed_dir`**: Directory where the processed dataset will be saved.\n", - "- **`--x_size`, `--y_size`**: These parameters specify the size of the tiles 
used to divide the image. The size of the tiles determines how the spatial region is partitioned for processing.\n", - "- **`--d_x`, `--d_y`**: These define the step size of the spatial grid used to bin transcripts and nuclei into tiles.\n", - "- **`--r_tx`**: Specifies the radius used for graph construction. A smaller radius will connect transcripts to nearby nuclei, while a larger radius might connect them to more distant neighbors.\n", - "- **`--scale_boundaries`**: The factor by which to scale the boundary polygons. Suggested to keep `=1` when boundaries refer to nuclei.\n", - "- **`--k_tx`**: Defines the number of nearest neighbors considered when building graphs for transcripts (`k_tx`).\n", - "- **`--val_prob` and `--test_prob`**: These control the proportion of the dataset that will be set aside for validation and testing. For instance, `--val_prob 0.1` means 10% of the data will be used for validation.\n", - "- **`--compute_labels`**: When set to `True`, this flag triggers the computation of labels (cell assignments) for each transcript. Use False if you just plan to perform prediction using a pre-existing model.\n", - "\n", - "Once the dataset is processed, the output will be ready for training the Segger model.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c8cf7102-ad9c-4bd0-bbd7-61a1d73abccd", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-12T00:49:06.357990Z", - "iopub.status.busy": "2024-09-12T00:49:06.357793Z", - "iopub.status.idle": "2024-09-12T00:49:07.235307Z", - "shell.execute_reply": "2024-09-12T00:49:07.234925Z", - "shell.execute_reply.started": "2024-09-12T00:49:06.357975Z" - }, - "id": "c8cf7102-ad9c-4bd0-bbd7-61a1d73abccd", - "scrolled": true - }, - "outputs": [], - "source": [ - "try:\n", - " xs.save_dataset_for_segger(\n", - " processed_dir=segger_data_dir,\n", - " r_tx=5,\n", - " k_tx=15,\n", - " x_size=120,\n", - " y_size=120,\n", - " d_x=100,\n", - " d_y=100,\n", - " margin_x=10,\n", - " margin_y=10,\n", - " scale_boundaries=1,\n", - " num_workers=4, # change to your number of CPUs\n", - " )\n", - "except AssertionError as err:\n", - " print(f'Dataset already exists at {segger_data_dir}')" - ] + "source": [] }, { "cell_type": "markdown", @@ -246,13 +141,12 @@ "id": "9d2b090b" }, "source": [ - "### **1.2 Faster Dataset Creation with Segger**\n", + "### **1.1. Fast Dataset Creation with segger**\n", "\n", - "Segger introduces a faster, more efficient pipeline for processing spatial transcriptomics data. This method accelerates dataset creation, particularly for large datasets, by using **ND-tree-based spatial partitioning** and **parallel processing**. This results in a much faster preparation of the dataset, which is saved in PyTorch Geometric (PyG) format, similar to the previous method.\n", + "Segger introduces a fast and efficient pipeline for processing spatial transcriptomics data. This method accelerates dataset creation, particularly for large datasets, by using **ND-tree-based spatial partitioning** and **parallel processing**. 
This results in a much faster preparation of the dataset, which is saved in PyTorch Geometric (PyG) format, similar to the previous method.\n", "\n", "**Note**: The previous dataset creation method will soon be deprecated in favor of this optimized pipeline.\n", "\n", - "#### **Requirements for the Faster Pipeline**\n", "The pipeline requires the following inputs:\n", "\n", "- **base_dir**: The directory containing the raw dataset.\n", @@ -260,7 +154,6 @@ "\n", "The core improvements in this method come from the use of **ND-tree partitioning**, which splits the data efficiently into spatial regions, and **parallel processing**, which speeds up the handling of these regions across multiple CPU cores. For example, using this pipeline, the Xenium Human Pancreatic Dataset can be processed in just a few minutes when running with 16 workers.\n", "\n", - "#### **Running the Faster Dataset Creation Pipeline**\n", "Below is an example of how to create a dataset using the faster Segger pipeline:" ] }, { "cell_type": "code", "execution_count": null, "id": "vlDtoWZb24FJ", "metadata": { "id": "vlDtoWZb24FJ" }, "outputs": [], - "source": [ - "from segger.data.parquet.sample import STSampleParquet" - ] + "source": [] }, { "cell_type": "code", @@ -291,7 +182,7 @@ "sample = STSampleParquet(\n", " base_dir=xenium_data_dir,\n", " n_workers=4,\n", - " sample_type='xenium',\n", + " sample_type='xenium', # this could be 'xenium_v2' in case one uses the cell boundaries from the segmentation kit.\n", " # weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available\n", ")\n", "\n", @@ -341,44 +232,51 @@ }, { "cell_type": "markdown", - "id": "9962e4b8-4028-4683-9b75-d674fa6fb01d", - "metadata": { - "id": "9962e4b8-4028-4683-9b75-d674fa6fb01d" - }, + "id": "70755046", + "metadata": {}, "source": [ - "# **2. Train your Segger Model**\n", + "### **1.2. Using custom gene embeddings**\n", "\n", - "The Segger model training process begins after the dataset has been created. This model is a **heterogeneous graph neural network (GNN)** designed to segment single cells by leveraging both nuclei and transcript data.\n", - "\n", - "Segger uses graph attention layers to propagate information across nodes (nuclei and transcripts) and refine cell boundaries. The model architecture includes initial embedding layers, attention-based graph convolutions, and residual connections for stable learning.\n", + "In the default mode, segger initially tokenizes transcripts based on their gene type in a simple one-hot manner. However, one can use other gene embeddings (e.g., pre-trained embeddings). The following example shows how one can employ a cell-type-annotated scRNAseq reference of the same tissue type (not necessarily the same sample or experiment) to embed genes based on their abundance in different cell types:" ] }, { "cell_type": "code", "execution_count": null, "id": "3bf18259", "metadata": {}, "outputs": [], "source": [ "from segger.data.utils import calculate_gene_celltype_abundance_embedding\n", "scrnaseq_file = Path('my_scRNAseq_file.h5ad')\n", "celltype_column = 'celltype_column'\n", "gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(\n", " sc.read(scrnaseq_file),\n", " celltype_column\n", ")\n", "\n", - "Segger leverages the **PyTorch Lightning** framework to streamline the training and evaluation of its graph neural network (GNN). PyTorch Lightning simplifies the training process by abstracting away much of the boilerplate code, allowing users to focus on model development and experimentation. 
It also supports multi-GPU training, mixed-precision, and efficient scaling, making it an ideal framework for training complex models like Segger.\n" + "sample = STSampleParquet(\n", + " base_dir=xenium_data_dir,\n", + " n_workers=4,\n", + " sample_type='xenium', # this could be 'xenium_v2' in case one uses the cell boundaries from the segmentation kit.\n", + " weights=gene_celltype_abundance_embedding, \n", + ")" ] }, { "cell_type": "markdown", - "id": "8cbf5be9-27f3-45c2-ab28-8d8ceb078745", + "id": "9962e4b8-4028-4683-9b75-d674fa6fb01d", "metadata": { - "id": "8cbf5be9-27f3-45c2-ab28-8d8ceb078745" + "id": "9962e4b8-4028-4683-9b75-d674fa6fb01d" }, "source": [ - "Key parameters for training:\n", - "- **`--data_dir`**: Directory containing the training data.\n", - "- **`--model_dir`**: Directory in which to store models.\n", - "- **`--epochs`**: Specifies the number of training epochs.\n", - "- **`--batch_size`**: Batch sizes for training and validation data.\n", - "- **`--learning_rate`**: The initial learning rate for the optimizer.\n", - "- **`--hidden_channels`**: Number of hidden channels in the GNN layers.\n", - "- **`--heads`**: Number of attention heads used in each graph convolutional layer.\n", - "- **`--init_emb`**: Sets the dimensionality of the initial embeddings applied to the input node features (e.g., transcripts). A higher embedding dimension may capture more feature complexity but also requires more computation.\n", - "- **`--out_channels`**: Specifies the number of output channels after the final graph attention layer, e.g. the final learned representations of the graph nodes.\n", + "# **2. Train your Segger Model**\n", "\n", - "Additional Options for Training the Segger Model:\n", + "The Segger model training process begins after the dataset has been created. This model is a **heterogeneous graph neural network (GNN)** designed to segment single cells by leveraging both nuclei and transcript data.\n", "\n", - "- **`--aggr`**: This option controls the aggregation method used in the graph convolution layers.\n", - "- **`--accelerator`**: Controls the hardware used for training, such as `cuda` for GPU training. This enables Segger to leverage GPU resources for faster training, especially useful for large datasets.\n", - "- **`--strategy`**: Defines the distributed training strategy, with `auto` allowing PyTorch Lightning to automatically configure the best strategy based on the hardware setup.\n", - "- **`--precision`**: Enables mixed precision training (e.g., `16-mixed`), which can speed up training and reduce memory usage while maintaining accuracy." + "Segger uses graph attention layers to propagate information across nodes (nuclei and transcripts) and refine cell boundaries. The model architecture includes initial embedding layers, attention-based graph convolutions, and residual connections for stable learning.\n", + "\n", + "Segger leverages the **PyTorch Lightning** framework to streamline the training and evaluation of its graph neural network (GNN). PyTorch Lightning simplifies the training process by abstracting away much of the boilerplate code, allowing users to focus on model development and experimentation. It also supports multi-GPU training, mixed-precision, and efficient scaling, making it an ideal framework for training complex models like Segger." 
] }, { @@ -404,11 +302,11 @@ "metadata = ([\"tx\", \"bd\"], [(\"tx\", \"belongs\", \"bd\"), (\"tx\", \"neighbors\", \"tx\")])\n", "ls = LitSegger(\n", " num_tx_tokens=500,\n", - " init_emb=8,\n", - " hidden_channels=32,\n", - " out_channels=8,\n", - " heads=2,\n", - " num_mid_layers=2,\n", + " init_emb=8, \n", + " hidden_channels=64,\n", + " out_channels=16,\n", + " heads=4,\n", + " num_mid_layers=1,\n", " aggr='sum',\n", " metadata=metadata,\n", ")\n", @@ -432,7 +330,7 @@ " accelerator='cuda',\n", " strategy='auto',\n", " precision='16-mixed',\n", - " devices=1,\n", + " devices=1, # set higher number if more gpus are available\n", " max_epochs=100,\n", " default_root_dir=models_dir,\n", " logger=CSVLogger(models_dir),\n", @@ -456,6 +354,30 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "e6214a79", + "metadata": {}, + "source": [ + "Key parameters for training:\n", + "- **`--data_dir`**: Directory containing the training data.\n", + "- **`--model_dir`**: Directory in which to store models.\n", + "- **`--epochs`**: Specifies the number of training epochs.\n", + "- **`--batch_size`**: Batch sizes for training and validation data.\n", + "- **`--learning_rate`**: The initial learning rate for the optimizer.\n", + "- **`--hidden_channels`**: Number of hidden channels in the GNN layers.\n", + "- **`--heads`**: Number of attention heads used in each graph convolutional layer.\n", + "- **`--init_emb`**: Sets the dimensionality of the initial embeddings applied to the input node features (e.g., transcripts). A higher embedding dimension may capture more feature complexity but also requires more computation.\n", + "- **`--out_channels`**: Specifies the number of output channels after the final graph attention layer, e.g. the final learned representations of the graph nodes.\n", + "\n", + "Additional Options for Training the Segger Model:\n", + "\n", + "- **`--aggr`**: This option controls the aggregation method used in the graph convolution layers.\n", + "- **`--accelerator`**: Controls the hardware used for training, such as `cuda` for GPU training. This enables Segger to leverage GPU resources for faster training, especially useful for large datasets.\n", + "- **`--strategy`**: Defines the distributed training strategy, with `auto` allowing PyTorch Lightning to automatically configure the best strategy based on the hardware setup.\n", + "- **`--precision`**: Enables mixed precision training (e.g., `16-mixed`), which can speed up training and reduce memory usage while maintaining accuracy." + ] + }, { "cell_type": "markdown", "id": "9a7d20c6-ca16-4beb-b627-afb41e3fb491", @@ -537,45 +459,7 @@ "source": [ "# **3. Make Predictions**\n", "\n", - "Once the Segger model is trained, it can be used to make predictions on seen or unseen data. This step involves using a trained checkpoint to predict cell boundaries and refine transcript-nuclei associations.\n", - "\n", - "Key parameters for making predictions:\n", - "- **`--checkpoint_path`**: Path to the trained model checkpoint, which stores the learned weights.\n", - "- **`--batch_size`**: Batch size used during inference.\n", - "- **`--score_cut`**: Defines the score threshold for classifying predictions. 
Higher values of `score_cut` make the model more conservative in associating transcripts with nuclei.\n", - "- **`--receptive_field`**: These parameters once again define the nearest neighbors for nuclei (`k_bd`) and transcripts (`k_tx`) and their distances (`dist_bd` and `dist_tx`) during the prediction stage.\n", - "- **`--use_cc`**: Used when some **transcripts are not directly associated with any nucleus**—a common scenario when a nucleus isn't captured on the slide or within the field of view. In these cases, Segger uses **connected components (CC)** to group such \"nucleus-less\" transcripts into distinct cells. Even though these transcripts lack a directly associated nucleus, they likely still represent a real cell, and grouping them together ensures that these cells are not discarded.\n", - "\n", - "The predictions can be saved and visualized to assess the segmentation quality.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "d4279c71-4660-46fc-a9e5-834e25d31f53", - "metadata": { - "id": "d4279c71-4660-46fc-a9e5-834e25d31f53" - }, - "outputs": [], - "source": [ - "# Checkpoint directory for Lightning model above\n", - "model_version = 0\n", - "\n", - "# Load in latest checkpoint\n", - "model_path = models_dir / 'lightning_logs' / f'version_{model_version}'\n", - "model = load_model(model_path / 'checkpoints')\n", - "dm.setup()\n", - "\n", - "receptive_field = {'k_bd': 4, 'dist_bd': 12,'k_tx': 15, 'dist_tx': 3}\n", - "\n", - "# Perform segmentation (predictions)\n", - "segmentation = predict(\n", - " model,\n", - " dm.train_dataloader(),\n", - " score_cut=0.33,\n", - " receptive_field=receptive_field,\n", - " use_cc=False,\n", - ")" + "Once the Segger model is trained, it can be used to make predictions on seen (partially trained) data or be transfered to unseen data. This step involves using a trained checkpoint to predict cell boundaries and refine transcript-nuclei associations." ] }, { @@ -585,11 +469,6 @@ "id": "9807abf3" }, "source": [ - "### **3.2 Faster Prediction with Segger**\n", - "We introduce a faster and more efficient pipeline for making predictions using a segger model. 
This new method accelerates the segmentation process by using CUDA-accelerated **nearest neighbors search** using [CAGRA](https://docs.rapids.ai/api/cuvs/stable/python_api/neighbors_cagra/) and **parallel processing**.\n", - "\n", - "**Note**: The previous prediction method will soon be deprecated in favor of this optimized pipeline.\n", - "\n", "#### **Requirements for the Faster Prediction Pipeline**\n", "The pipeline requires the following inputs:\n", "\n", @@ -602,18 +481,6 @@ "Below is an example of how to run the faster Segger prediction pipeline using the command line:" ] }, - { - "cell_type": "code", - "execution_count": 17, - "id": "3e802c3f", - "metadata": { - "id": "3e802c3f" - }, - "outputs": [], - "source": [ - "from segger.prediction.predict_parquet import segment, load_model" - ] - }, { "cell_type": "code", "execution_count": null, @@ -653,7 +520,7 @@ " model,\n", " dm,\n", " save_dir='benchmarks',\n", - " seg_tag='segger_embedding_1001',\n", + " seg_tag='segger_output',\n", " transcript_file='data_xenium/transcripts.parquet',\n", " receptive_field=receptive_field,\n", " min_transcripts=5,\n", @@ -664,6 +531,12 @@ ")\n" ] }, + { + "cell_type": "markdown", + "id": "3aa5002c", + "metadata": {}, + "source": [] + }, { "cell_type": "markdown", "id": "0a823035", diff --git a/scripts/create_data_fast_sample.py b/scripts/create_data_fast_sample.py new file mode 100644 index 0000000..faa147c --- /dev/null +++ b/scripts/create_data_fast_sample.py @@ -0,0 +1,36 @@ +from segger.data.parquet.sample import STSampleParquet +from path import Path +from segger.data.utils import calculate_gene_celltype_abundance_embedding +import scanpy as sc + +xenium_data_dir = Path('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/') +segger_data_dir = Path('data_tidy/pyg_datasets/bc_fast_data_emb_major') + + +scrnaseq_file = Path('data_tidy/benchmarks/xe_rep1_bc/scRNAseq.h5ad') +celltype_column = 'celltype_major' +gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding( + sc.read(scrnaseq_file), + celltype_column +) + +sample = STSampleParquet( + base_dir=xenium_data_dir, + n_workers=4, + sample_type='xenium', + weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available +) + +sample.save( + data_dir=segger_data_dir, + k_bd=3, + dist_bd=15.0, + k_tx=20, + dist_tx=3, + tile_width=220, + tile_height=220, + neg_sampling_ratio=5.0, + frac=1.0, + val_prob=0.1, + test_prob=0.1, +) \ No newline at end of file diff --git a/scripts/predict_model_sample.py b/scripts/predict_model_sample.py index 00b133c..9624abc 100644 --- a/scripts/predict_model_sample.py +++ b/scripts/predict_model_sample.py @@ -16,8 +16,12 @@ import dask.dataframe as dd -segger_data_dir = Path("./data_tidy/pyg_datasets/bc_embedding_1001") -models_dir = Path("./models/bc_embedding_1001_small") + +seg_tag = "bc_fast_data_emb_major" +model_version = 1 + +segger_data_dir = Path('data_tidy/pyg_datasets') / seg_tag +models_dir = Path("./models") / seg_tag benchmarks_dir = Path("/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc") transcripts_file = "data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet" # Initialize the Lightning data module @@ -30,63 +34,28 @@ dm.setup() -model_version = 0 - # Load in latest checkpoint model_path = models_dir / "lightning_logs" / f"version_{model_version}" model = load_model(model_path / "checkpoints") -receptive_field = {"k_bd": 4, "dist_bd": 20, "k_tx": 5, "dist_tx": 3} 
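+# Receptive field used at prediction time: k_bd/dist_bd set the number and maximum distance of boundary (nucleus) neighbors, k_tx/dist_tx the same for transcript neighbors.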
+receptive_field = {"k_bd": 4, "dist_bd": 15, "k_tx": 5, "dist_tx": 3} segment( model, dm, save_dir=benchmarks_dir, - seg_tag="parquet_test_big", + seg_tag=seg_tag, transcript_file=transcripts_file, # file_format='anndata', receptive_field=receptive_field, min_transcripts=5, - score_cut=0.5, + score_cut=0.1, # max_transcripts=1500, cell_id_col="segger_cell_id", - use_cc=True, - knn_method="cuda", + use_cc=False, + knn_method="kd_tree", verbose=True, gpu_ids=["0"], # client=client ) - -# if __name__ == "__main__": -# cluster = LocalCUDACluster( -# # CUDA_VISIBLE_DEVICES="0", -# n_workers=1, -# dashboard_address=":8080", -# memory_limit='30GB', # Adjust based on system memory -# lifetime="2 hours", # Increase worker lifetime -# lifetime_stagger="75 minutes", -# local_directory='.', # Stagger worker restarts -# lifetime_restart=True # Automatically restart workers -# ) -# client = Client(cluster) - -# segment( -# model, -# dm, -# save_dir=benchmarks_dir, -# seg_tag='segger_embedding_0926_mega_0.5_20', -# transcript_file=transcripts_file, -# file_format='anndata', -# receptive_field = receptive_field, -# min_transcripts=5, -# score_cut=0.5, -# # max_transcripts=1500, -# cell_id_col='segger_cell_id', -# use_cc=False, -# knn_method='cuda', -# # client=client -# ) - -# client.close() -# cluster.close() diff --git a/scripts/sand.py b/scripts/sand.py new file mode 100644 index 0000000..447c5e6 --- /dev/null +++ b/scripts/sand.py @@ -0,0 +1,503 @@ +from pqdm.processes import pqdm +from tqdm import tqdm +from pathlib import Path +import pandas as pd +import numpy as np +import zarr +from scipy.spatial import ConvexHull +from typing import Dict, Any, Optional, List, Tuple +from shapely.geometry import Polygon, MultiPolygon +from shapely.ops import unary_union + + +def get_boundary(seg_cell, x="x_location", y="y_location"): + """Calculate the boundary for the cell.""" + if len(seg_cell) < 3: + return None + bi = BoundaryIdentification(seg_cell[[x, y]].values) + bi.calculate_part_1(plot=False) + bi.calculate_part_2(plot=False) + return bi.find_cycles() + + +def process_group(group, area_low, area_high): + """Process each group for parallel computation.""" + cell_incremental_id, (seg_cell_id, seg_cell) = group + # print(seg_cell) + # print(cell_incremental_id) + # print(seg_cell) + if len(seg_cell) < 5: + return None + # Cell boundary using `get_boundary` + cell_boundary = generate_boundary(seg_cell) + if isinstance(cell_boundary, MultiPolygon): + # polygons = sorted(cell_boundary, key=lambda p: p.area, reverse=True) + # Extract the largest polygon (by area) from MultiPolygon + # cell_boundary = polygons[0] + # cell_boundary = unary_union(cell_boundary) + # polygons = sorted(cell_boundary, key=lambda p: p.area, reverse=True) + # cell_boundary = polygons[0] + cell_boundary = max(cell_boundary, key=lambda p: p.area) + print(cell_boundary.area) + # print(cell_boundary.area) + if cell_boundary is None or cell_boundary.area > area_high or cell_boundary.area < area_low: + # print('**********************************') + # print(cell_boundary.area) + # print('**********************************') + return None + uint_cell_id = cell_incremental_id + 1 + # print(uint_cell_id) + # seg_nucleus = seg_cell[seg_cell["overlaps_nucleus"] == 1] + # Nucleus boundary using ConvexHull + # nucleus_boundary = None + # if len(seg_nucleus) >= 3: + # try: + # nucleus_boundary = ConvexHull(seg_nucleus[["x_location", "y_location"]]) + # except Exception: + # pass + # Prepare data for the final output + return { + "uint_cell_id": 
uint_cell_id, + "seg_cell_id": seg_cell_id, + "cell_boundary": cell_boundary, + # "nucleus_boundary": nucleus_boundary, + "cell_summary": { + "cell_centroid_x": seg_cell["x_location"].mean(), + "cell_centroid_y": seg_cell["y_location"].mean(), + "cell_area": cell_boundary.area, + # "nucleus_centroid_x": seg_cell["x_location"].mean(), + # "nucleus_centroid_y": seg_cell["y_location"].mean(), + # "nucleus_area": nucleus_boundary.area if nucleus_boundary else 0, + "z_level": (seg_cell.z_location.mean() // 3).round(0) * 3, + }, + "seg_mask_value": uint_cell_id, + } + + + +def get_coordinates(boundary): + """Extracts coordinates from a Polygon or MultiPolygon.""" + if isinstance(boundary, MultiPolygon): + # Combine coordinates from all polygons in the MultiPolygon + coords = [] + for polygon in boundary.geoms: + coords.extend(polygon.exterior.coords) + return coords + elif isinstance(boundary, Polygon): + # Return coordinates from a single Polygon + return list(boundary.exterior.coords) + + +def get_flatten_version(polygon_vertices, max_value=21): + """ + Flattens and standardizes the shape of polygon vertices to a fixed length. + + Args: + polygon_vertices (list): List of polygons where each polygon is a list of coordinates. + max_value (int): The fixed number of vertices per polygon. + + Returns: + np.array: A standardized array of polygons with exactly max_value vertices each. + """ + flattened = [] + for vertices in polygon_vertices: + # Pad or truncate each polygon to the max_value + if len(vertices) > max_value: + flattened.append(vertices[:max_value]) + else: + flattened.append(vertices + [(0.0, 0.0)] * (max_value - len(vertices))) + return np.array(flattened, dtype=np.float32) + +from segger.validation.xenium_explorer import * +from segger.prediction.boundary import * +def seg2explorer( + seg_df: pd.DataFrame, + source_path: str, + output_dir: str, + cells_filename: str = "seg_cells", + analysis_filename: str = "seg_analysis", + xenium_filename: str = "seg_experiment.xenium", + analysis_df: Optional[pd.DataFrame] = None, + draw: bool = False, + cell_id_columns: str = "seg_cell_id", + area_low: float = 10, + area_high: float = 100, + n_jobs: int = 4, +) -> None: + """Convert seg output to a format compatible with Xenium explorer.""" + source_path = Path(source_path) + storage = Path(output_dir) + # Group by cell_id_columns + grouped_by = list(seg_df.groupby(cell_id_columns)) + # Process groups in parallel with pqdm + # results = pqdm( + # tqdm(enumerate(grouped_by), desc="Processing Cells", total=len(grouped_by)), + # lambda group: process_group(group, area_low, area_high), + # n_jobs=n_jobs, + # ) + results = [] + for idx, group in tqdm(enumerate(grouped_by), desc="Processing Cells", total=len(grouped_by)): + try: + res = process_group((idx, group), area_low, area_high) + if isinstance(res, dict): # Only keep valid results + results.append(res) + else: + print(f"Invalid result at index {idx}: {res}") + except Exception as e: + print(f"Error processing group at index {idx}: {e}") + # print(results) + print(results[0]) + # Filter out None results + results = [res for res in results if res] + # print('********************************0') + # Extract processed data + cell_id2old_id = {res["uint_cell_id"]: res["seg_cell_id"] for res in results} + cell_id = [res["uint_cell_id"] for res in results] + cell_summary = [res["cell_summary"] for res in results] + # print('********************************1') + polygon_num_vertices = [ + len(res["cell_boundary"].exterior.coords) for res in results + ] + 
print(polygon_num_vertices) + # polygon_vertices = np.array( + # [list(res["cell_boundary"].exterior.coords) for res in results], + # dtype=object + # ) + polygon_vertices = np.array( + [list(res["cell_boundary"].exterior.coords) for res in results], + dtype=object + ) + # polygon_vertices = get_flatten_version( + # [get_coordinates(res["cell_boundary"]) for res in results], + # max_value=21, + # ) + # print(polygon_vertices) + # print('********************************2') + seg_mask_value = [res["seg_mask_value"] for res in results] + # Convert results to Zarr + cells = { + "cell_id": np.array([np.array(cell_id), np.ones(len(cell_id))], dtype=np.uint32).T, + "cell_summary": pd.DataFrame(cell_summary).values.astype(np.float64), + "polygon_num_vertices": np.array(polygon_num_vertices).astype(np.int32), + "polygon_vertices": np.array(polygon_vertices).astype(np.float32), + "seg_mask_value": np.array(seg_mask_value, dtype=np.int32), + } + print(len(cells['cell_id'])) + # Save cells data + existing_store = zarr.open(source_path / "cells.zarr.zip", mode="r") + new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") + new_store["cell_id"] = cells["cell_id"] + new_store["polygon_num_vertices"] = cells["polygon_num_vertices"] + new_store["polygon_vertices"] = cells["polygon_vertices"] + new_store["seg_mask_value"] = cells["seg_mask_value"] + new_store.attrs.update(existing_store.attrs) + new_store.attrs["number_cells"] = len(cells["cell_id"]) + print(new_store) + new_store.store.close() + + print(cells["polygon_vertices"]) + # # Save analysis data + if analysis_df is None: + analysis_df = pd.DataFrame([cell_id2old_id[i] for i in cell_id], columns=[cell_id_columns]) + analysis_df["default"] = "seg" + zarr_df = pd.DataFrame([cell_id2old_id[i] for i in cell_id], columns=[cell_id_columns]) + clustering_df = pd.merge(zarr_df, analysis_df, how="left", on=cell_id_columns) + clusters_names = [i for i in analysis_df.columns if i != cell_id_columns] + clusters_dict = { + cluster: { + j: i + for i, j in zip( + range(1, len(sorted(np.unique(clustering_df[cluster].dropna()))) + 1), + sorted(np.unique(clustering_df[cluster].dropna())), + ) + } + for cluster in clusters_names + } + new_zarr = zarr.open(storage / (analysis_filename + ".zarr.zip"), mode="w") + new_zarr.create_group("/cell_groups") + clusters = [[clusters_dict[cluster].get(x, 0) for x in list(clustering_df[cluster])] for cluster in clusters_names] + for i in range(len(clusters)): + new_zarr["cell_groups"].create_group(i) + indices, indptr = get_indices_indptr(np.array(clusters[i])) + new_zarr["cell_groups"][i].create_dataset("indices", data=indices) + new_zarr["cell_groups"][i].create_dataset("indptr", data=indptr) + new_zarr["cell_groups"].attrs.update( + { + "major_version": 1, + "minor_version": 0, + "number_groupings": len(clusters_names), + "grouping_names": clusters_names, + "group_names": [ + [x[0] for x in sorted(clusters_dict[cluster].items(), key=lambda x: x[1])] for cluster in clusters_names + ], + } + ) + new_zarr.store.close() + # Generate experiment file + generate_experiment_file( + template_path=source_path / "experiment.xenium", + output_path=storage / xenium_filename, + cells_name=cells_filename, + analysis_name=analysis_filename, + ) + + + + +def seg2explorer( + seg_df: pd.DataFrame, + source_path: str, + output_dir: str, + cells_filename: str = "seg_cells", + analysis_filename: str = "seg_analysis", + xenium_filename: str = "seg_experiment.xenium", + analysis_df: Optional[pd.DataFrame] = None, + draw: bool = 
False, + cell_id_columns: str = "seg_cell_id", + area_low: float = 10, + area_high: float = 100, +) -> None: + """Convert seg output to a format compatible with Xenium explorer. + + Args: + seg_df (pd.DataFrame): The seg DataFrame. + source_path (str): The source path. + output_dir (str): The output directory. + cells_filename (str): The filename for cells. + analysis_filename (str): The filename for analysis. + xenium_filename (str): The filename for Xenium. + analysis_df (Optional[pd.DataFrame]): The analysis DataFrame. + draw (bool): Whether to draw the plots. + cell_id_columns (str): The cell ID columns. + area_low (float): The lower area threshold. + area_high (float): The upper area threshold. + """ + import zarr + import json + source_path = Path(source_path) + storage = Path(output_dir) + cell_id2old_id = {} + cell_id = [] + cell_summary = [] + polygon_num_vertices = [[], []] + polygon_vertices = [[], []] + seg_mask_value = [] + tma_id = [] + grouped_by = seg_df.groupby(cell_id_columns) + for cell_incremental_id, (seg_cell_id, seg_cell) in tqdm(enumerate(grouped_by), total=len(grouped_by)): + if len(seg_cell) < 5: + continue + # print('****************1') + # cell_convex_hull = ConvexHull(seg_cell[["x_location", "y_location"]]) + # print('****************2') + # hull_vertices = [seg_cell[["x_location", "y_location"]].values[vertex] for vertex in cell_convex_hull.vertices] + # print('****************3') + # # Create a Shapely Polygon + # cell_convex_hull = Polygon(hull_vertices) + cell_convex_hull = get_boundary(seg_cell) + # print(cell_convex_hull) + if isinstance(cell_convex_hull, MultiPolygon): + # polygons = sorted(cell_boundary, key=lambda p: p.area, reverse=True) + # Extract the largest polygon (by area) from MultiPolygon + # cell_boundary = polygons[0] + # cell_convex_hull = unary_union(cell_convex_hull) + # print('****************1') + # polygons = sorted(cell_convex_hull.geoms, key=lambda p: p.area, reverse=True) + # cell_convex_hull = polygons[0] + continue + # print('****************2') + # cell_convex_hull = max(cell_convex_hull.geoms, key=lambda p: p.area) + # print(cell_convex_hull) + if cell_convex_hull.area > area_high: + continue + if cell_convex_hull.area < area_low: + continue + uint_cell_id = cell_incremental_id + 1 + cell_id2old_id[uint_cell_id] = seg_cell_id + seg_nucleous = seg_cell[seg_cell["overlaps_nucleus"] == 1] + if len(seg_nucleous) >= 3: + nucleus_convex_hull = ConvexHull(seg_nucleous[["x_location", "y_location"]]) + cell_id.append(uint_cell_id) + cell_summary.append( + { + "cell_centroid_x": seg_cell["x_location"].mean(), + "cell_centroid_y": seg_cell["y_location"].mean(), + "cell_area": cell_convex_hull.area, + "nucleus_centroid_x": seg_cell["x_location"].mean(), + "nucleus_centroid_y": seg_cell["y_location"].mean(), + "nucleus_area": cell_convex_hull.area, + "z_level": (seg_cell.z_location.mean() // 3).round(0) * 3, + } + ) + polygon_num_vertices[0].append(len(cell_convex_hull.exterior.coords)) + polygon_num_vertices[1].append(len(nucleus_convex_hull.vertices) if len(seg_nucleous) >= 3 else 0) + polygon_vertices[0].append(cell_convex_hull.exterior.coords) + polygon_vertices[1].append( + seg_nucleous[["x_location", "y_location"]].values[nucleus_convex_hull.vertices] + if len(seg_nucleous) >= 3 + else np.array([[], []]).T + ) + seg_mask_value.append(cell_incremental_id + 1) + cell_polygon_vertices = get_flatten_version(polygon_vertices[0], max_value=128) + nucl_polygon_vertices = get_flatten_version(polygon_vertices[1], max_value=128) + cells = { + 
"cell_id": np.array([np.array(cell_id), np.ones(len(cell_id))], dtype=np.uint32).T, + "cell_summary": pd.DataFrame(cell_summary).values.astype(np.float64), + "polygon_num_vertices": np.array( + [ + [min(x + 1, x + 1) for x in polygon_num_vertices[1]], + [min(x + 1, x + 1) for x in polygon_num_vertices[0]], + ], + dtype=np.int32, + ), + "polygon_vertices": np.array([nucl_polygon_vertices, cell_polygon_vertices]).astype(np.float32), + "seg_mask_value": np.array(seg_mask_value, dtype=np.int32), + } + existing_store = zarr.open(source_path / "cells.zarr.zip", mode="r") + new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") + new_store["cell_id"] = cells["cell_id"] + new_store["polygon_num_vertices"] = cells["polygon_num_vertices"] + new_store["polygon_vertices"] = cells["polygon_vertices"] + new_store["seg_mask_value"] = cells["seg_mask_value"] + new_store.attrs.update(existing_store.attrs) + new_store.attrs["number_cells"] = len(cells["cell_id"]) + new_store.store.close() + if analysis_df is None: + analysis_df = pd.DataFrame([cell_id2old_id[i] for i in cell_id], columns=[cell_id_columns]) + analysis_df["default"] = "seg" + zarr_df = pd.DataFrame([cell_id2old_id[i] for i in cell_id], columns=[cell_id_columns]) + clustering_df = pd.merge(zarr_df, analysis_df, how="left", on=cell_id_columns) + clusters_names = [i for i in analysis_df.columns if i != cell_id_columns] + clusters_dict = { + cluster: { + j: i + for i, j in zip( + range(1, len(sorted(np.unique(clustering_df[cluster].dropna()))) + 1), + sorted(np.unique(clustering_df[cluster].dropna())), + ) + } + for cluster in clusters_names + } + new_zarr = zarr.open(storage / (analysis_filename + ".zarr.zip"), mode="w") + new_zarr.create_group("/cell_groups") + clusters = [[clusters_dict[cluster].get(x, 0) for x in list(clustering_df[cluster])] for cluster in clusters_names] + for i in range(len(clusters)): + new_zarr["cell_groups"].create_group(i) + indices, indptr = get_indices_indptr(np.array(clusters[i])) + new_zarr["cell_groups"][i].create_dataset("indices", data=indices) + new_zarr["cell_groups"][i].create_dataset("indptr", data=indptr) + new_zarr["cell_groups"].attrs.update( + { + "major_version": 1, + "minor_version": 0, + "number_groupings": len(clusters_names), + "grouping_names": clusters_names, + "group_names": [ + [x[0] for x in sorted(clusters_dict[cluster].items(), key=lambda x: x[1])] for cluster in clusters_names + ], + } + ) + new_zarr.store.close() + generate_experiment_file( + template_path=source_path / "experiment.xenium", + output_path=storage / xenium_filename, + cells_name=cells_filename, + analysis_name=analysis_filename, + ) + + + + +ddf = dd.read_parquet('/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc/parquet_train_big_0.5_False_3_10_5_3_20241030/segger_transcripts.parquet').compute() +ddf = ddf.dropna() +ddf = ddf[ddf.segger_cell_id != 'None'] +ddf = ddf.sort_values('segger_cell_id') +df = ddf.iloc[:10000, :] + + + +df_path = Path('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep1_v9_segger.csv.gz') +df_v9 = dd.read_csv(df_path) +df_main = dd.read_parquet('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/transcripts.parquet') + +ddf = df_v9.merge(df_main, on='transcript_id') +ddf = ddf.compute() +ddf = ddf[ddf.segger_cell_id != 'None'] +ddf = ddf.sort_values('segger_cell_id') +df = ddf.loc[(ddf.x_location > 250) & (ddf.x_location < 1500) & (ddf.y_location > 500) & (ddf.y_location < 1500),:] + +# tx_df = 
dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') +# ddf = tx_df.merge(df_main, on='transcript_id') + +seg2explorer( + seg_df = df, + source_path = 'data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs', + output_dir = 'data_tidy/explorer/rep1sis', + cells_filename = "segger_cells_seg_roi1", + analysis_filename = "segger_analysis_seg_roi1", + xenium_filename = "segger_experiment_seg_roi1.xenium", + analysis_df = None, + cell_id_columns = "segger_cell_id", + area_low = 10, + area_high= 1000, +) + + + +df = ddf.loc[(ddf.x_location > 1550) & (ddf.x_location < 3250) & (ddf.y_location > 2250) & (ddf.y_location < 3550),:] + +# tx_df = dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') +# ddf = tx_df.merge(df_main, on='transcript_id') + +seg2explorer( + seg_df = df, + source_path = 'data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs', + output_dir = 'data_tidy/explorer/rep1sis', + cells_filename = "segger_cells_seg_roi2", + analysis_filename = "segger_analysis_seg_roi2", + xenium_filename = "segger_experiment_seg_roi2.xenium", + analysis_df = None, + cell_id_columns = "segger_cell_id", + area_low = 10, + area_high= 1000, +) + + + +df = ddf.loc[(ddf.x_location > 4000) & (ddf.x_location < 4500) & (ddf.y_location > 1000) & (ddf.y_location < 1500),:] + +# tx_df = dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') +# ddf = tx_df.merge(df_main, on='transcript_id') + +seg2explorer( + seg_df = df, + source_path = 'data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs', + output_dir = 'data_tidy/explorer/rep1sis', + cells_filename = "segger_cells_seg_roi3", + analysis_filename = "segger_analysis_seg_roi3", + xenium_filename = "segger_experiment_seg_roi3.xenium", + analysis_df = None, + cell_id_columns = "segger_cell_id", + area_low = 10, + area_high= 1000, +) + + +df = ddf.loc[(ddf.x_location > 1550) & (ddf.x_location < 3250) & (ddf.y_location > 2250) & (ddf.y_location < 3550),:] + +# tx_df = dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') +# ddf = tx_df.merge(df_main, on='transcript_id') + +seg2explorer( + seg_df = df, + source_path = 'data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs', + output_dir = 'data_tidy/explorer/rep1sis', + cells_filename = "segger_cells_seg_roi2", + analysis_filename = "segger_analysis_seg_roi2", + xenium_filename = "segger_experiment_seg_roi2.xenium", + analysis_df = None, + cell_id_columns = "segger_cell_id", + area_low = 10, + area_high= 1000, +) + diff --git a/scripts/train_model_sample.py b/scripts/train_model_sample.py index 8b834cc..411cee8 100644 --- a/scripts/train_model_sample.py +++ b/scripts/train_model_sample.py @@ -15,13 +15,13 @@ import os -segger_data_dir = Path("./data_tidy/pyg_datasets/bc_embedding_1001") -models_dir = Path("./models/bc_embedding_1001_small") +segger_data_dir = Path('data_tidy/pyg_datasets/bc_fast_data_emb_minor') +models_dir = Path("./models/bc_fast_data_emb_minor") dm = SeggerDataModule( data_dir=segger_data_dir, - batch_size=4, - num_workers=2, + batch_size=2, + num_workers=1, ) dm.setup() @@ -30,10 +30,10 @@ ls = LitSegger( num_tx_tokens=500, init_emb=8, - hidden_channels=32, - out_channels=8, - heads=2, - num_mid_layers=2, + hidden_channels=64, + out_channels=16, + heads=4, + num_mid_layers=1, aggr="sum", metadata=metadata, ) @@ -44,7 +44,7 @@ strategy="auto", precision="16-mixed", devices=4, - max_epochs=200, + max_epochs=400, default_root_dir=models_dir, 
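+    # Checkpoints and CSV logs are written under models_dir/lightning_logs/version_<n>/, which is the path predict_model_sample.py loads the trained checkpoint from.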
logger=CSVLogger(models_dir), ) diff --git a/src/segger/data/Untitled-4.py b/src/segger/data/Untitled-4.py new file mode 100644 index 0000000..51b3034 --- /dev/null +++ b/src/segger/data/Untitled-4.py @@ -0,0 +1,66 @@ +df_path = Path('/omics/groups/OE0606/internal/gleb/xenium/elyas/rep1/baysor_rep1.csv') +df = dd.read_csv(df_path) + +res = compute_cells_per_nucleus(df, new_cell_col='baysor_cell_id') + + +plot_gradient_nuclei_histogram(res, figures_path) + + + + + +df_path = Path('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep1_v9_segger.csv.gz') +df_v9 = dd.read_csv(df_path) +df_main = dd.read_parquet('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/transcripts.parquet') + +ddf = df_v9.merge(df_main, on='transcript_id') +# df = dd.read_parquet(segger_emb_path/'segger_transcripts.parquet') +ddf1 = ddf.compute() +res = compute_cells_per_nucleus(df.compute(), new_cell_col='segger_cell_id') +plot_gradient_nuclei_histogram(res, figures_path, 'segger') + +dfff = df.compute() +dfff = dfff.dropna() +gene_dropout = calculate_gene_dropout_rate(dfff) + + +plot_z_distance_distribution(dfff, figures_path, title='Z-Distance Distribution') + +plot_z_distance_boxplot(dfff, figures_path, title='Z-Distance Distribution') + + +tx_df = dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') +ddf = tx_df.merge(df_main, on='transcript_id') + + + +from segger.prediction.boundary import generate_boundary +import geopandas as gpd +from tqdm import tqdm + +bb = generate_boundaries(ddf, x="x_location", y="y_location", cell_id="segger_cell_id", n_jobs=8) + + +from pqdm.processes import pqdm # or use pqdm.threads for threading-based parallelism + +# Modify the function to work with a single group to use with pqdm +def process_group(group): + cell_id, t = group + return { + "cell_id": cell_id, + "length": len(t), + "geom": generate_boundary(t, x="x_location", y="y_location") + } + +def generate_boundaries(df, x="x_location", y="y_location", cell_id="segger_cell_id", n_jobs=10): + # Group by cell_id + group_df = df.groupby(cell_id) + # Use pqdm to process each group in parallel + results = pqdm(tqdm(group_df, desc="Processing Groups"), process_group, n_jobs=n_jobs) + # Convert results to GeoDataFrame + return gpd.GeoDataFrame( + data=[[res["cell_id"], res["length"]] for res in results], + geometry=[res["geom"] for res in results], + columns=["cell_id", "length"], + ) diff --git a/src/segger/prediction/predict_parquet.py b/src/segger/prediction/predict_parquet.py index c6d39ef..e4d499b 100644 --- a/src/segger/prediction/predict_parquet.py +++ b/src/segger/prediction/predict_parquet.py @@ -189,7 +189,7 @@ def get_similarity_scores( to_type: str, receptive_field: dict, compute_sigmoid: bool = True, - knn_method: str = "cuda", + knn_method: str = "kd_tree", gpu_id: int = 0, # Added argument for GPU ID ) -> coo_matrix: """ @@ -217,19 +217,24 @@ def get_similarity_scores( # Step 1: Get embeddings from the model (on GPU) shape = batch[from_type].x.shape[0], batch[to_type].x.shape[0] + if from_type == to_type: + coords_1 = coords_2 = batch[to_type].pos + else: + coords_1 = batch[to_type].pos[:, :2] # 'tx' positions + coords_2 = batch[from_type].pos[:, :2] if knn_method == "kd_tree": # Compute edge indices using knn method (still on GPU) edge_index = get_edge_index( - batch[to_type].pos[:, :2].cpu(), # 'tx' positions - batch[from_type].pos[:, :2].cpu(), # 'bd' positions + coords_1.cpu(), + coords_2.cpu(), k=receptive_field[f"k_{to_type}"], dist=receptive_field[f"dist_{to_type}"], 
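+            # knn_method='kd_tree' runs the neighbor search on CPU tensors; other methods (e.g. 'cuda') keep the search on the GPU in the branch below.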
method=knn_method, ) else: edge_index = get_edge_index( - batch[to_type].pos[:, :2], # 'tx' positions - batch[from_type].pos[:, :2], # 'bd' positions + coords_1, + coords_2, k=receptive_field[f"k_{to_type}"], dist=receptive_field[f"dist_{to_type}"], method=knn_method, @@ -239,7 +244,15 @@ edge_index = coo_to_dense_adj(edge_index.T, num_nodes=shape[0], num_nbrs=receptive_field[f"k_{to_type}"]) with torch.no_grad(): - embeddings = model(batch.x_dict, batch.edge_index_dict) + if from_type != to_type: + embeddings = model(batch.x_dict, batch.edge_index_dict) + else: # use the initial embeddings for tx-tx + embeddings = {key: model.node_init[key](x) for key, x in batch.x_dict.items()} + norms = embeddings[to_type].norm(dim=1, keepdim=True) + # Avoid division by zero in case there are zero vectors + norms = torch.where(norms == 0, torch.ones_like(norms), norms) + # Normalize the embeddings to unit length + embeddings[to_type] = embeddings[to_type] / norms def sparse_multiply(embeddings, edge_index, shape) -> coo_matrix: m = torch.nn.ZeroPad2d((0, 0, 0, 1)) # pad bottom with zeros