From 6fb6e7ab318ef55e2b49e6cbe3f0034017416f73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=81NIEL=20UNYI?= Date: Wed, 8 Jan 2025 01:18:11 +0100 Subject: [PATCH 1/2] Update tutorial notebook --- docs/notebooks/segger_tutorial.ipynb | 1848 +++++++++++++------------- 1 file changed, 914 insertions(+), 934 deletions(-) diff --git a/docs/notebooks/segger_tutorial.ipynb b/docs/notebooks/segger_tutorial.ipynb index 6a3c11b..a15e7fa 100644 --- a/docs/notebooks/segger_tutorial.ipynb +++ b/docs/notebooks/segger_tutorial.ipynb @@ -1,940 +1,920 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "21ed4db6-5234-46b1-9f38-b5883ac88946", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-11T22:22:55.404267Z", - "iopub.status.busy": "2024-09-11T22:22:55.403876Z", - "iopub.status.idle": "2024-09-11T22:22:58.089917Z", - "shell.execute_reply": "2024-09-11T22:22:58.089303Z", - "shell.execute_reply.started": "2024-09-11T22:22:55.404248Z" - }, - "id": "21ed4db6-5234-46b1-9f38-b5883ac88946" - }, - "source": [ - "# **Introduction to Segger**\n", - "\n", - "\n", - "**Important note (Dec 2024):** As segger is currently undergoing constant development we highly recommend installing directly via github.\n", - "\n", - "\n", - "Segger is a cutting-edge cell segmentation model specifically designed for **single-molecule resolved spatial omics** datasets. It addresses the challenge of accurately segmenting individual cells in complex imaging datasets, leveraging a unique approach based on graph neural networks (GNNs).\n", - "\n", - "The core idea behind Segger is to model both **nuclei** and **transcripts** as graph nodes, with edges connecting them based on their spatial proximity. This allows the model to learn from the co-occurrence of nucleic and cytoplasmic molecules, resulting in more refined and accurate cell boundaries. 
By using spatial information and GNNs, Segger achieves state-of-the-art performance in segmenting single cells in datasets such as 10X Xenium and MERSCOPE, outperforming traditional methods like Baysor and Cellpose.\n", - "\n", - "Segger's workflow consists of:\n", - "1. **Dataset creation**: Converting raw transcriptomic data into a graph-based dataset.\n", - "2. **Training**: Training the Segger model on the graph to learn cell boundaries.\n", - "3. **Prediction**: Using the trained model to make predictions on new datasets.\n", - "\n", - "This tutorial will guide you through each step of the process, ensuring you can train and apply Segger for your own data." - ] - }, - { - "cell_type": "markdown", - "id": "XEY6CTzK0648", - "metadata": { - "id": "XEY6CTzK0648" - }, - "source": [ - "Installing segger from the GitHub repository:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "TIQnPzfx08Zr", - "metadata": { - "id": "TIQnPzfx08Zr" - }, - "outputs": [], - "source": [ - "!git clone https://github.com/EliHei2/segger_dev.git\n", - "%cd segger_dev\n", - "!pip install \".[rapids12]\" -q" - ] - }, - { - "cell_type": "markdown", - "id": "q3SNnImS09_N", - "metadata": { - "id": "q3SNnImS09_N" - }, - "source": [ - "Downloading the [Xenium Human Pancreatic Dataset](https://www.10xgenomics.com/products/xenium-human-pancreatic-dataset-explorer):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "Qjdt3f-U0_i9", - "metadata": { - "id": "Qjdt3f-U0_i9" - }, - "outputs": [], - "source": [ - "!mkdir data_xenium\n", - "%cd data_xenium\n", - "!wget https://cf.10xgenomics.com/samples/xenium/1.6.0/Xenium_V1_hPancreas_Cancer_Add_on_FFPE/Xenium_V1_hPancreas_Cancer_Add_on_FFPE_outs.zip\n", - "!unzip Xenium_V1_hPancreas_Cancer_Add_on_FFPE_outs.zip\n", - "%cd .." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "trM8h-Ek16sJ", - "metadata": { - "id": "trM8h-Ek16sJ" - }, - "outputs": [], - "source": [ - "from segger.data.parquet.sample import STSampleParquet\n", - "from segger.training.segger_data_module import SeggerDataModule\n", - "from segger.training.train import LitSegger\n", - "from segger.prediction.predict_parquet import segment, load_model\n", - "from lightning.pytorch.loggers import CSVLogger\n", - "from pytorch_lightning import Trainer\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "from matplotlib import pyplot as plt\n", - "import seaborn as sns\n", - "import scanpy as sc" - ] - }, - { - "cell_type": "markdown", - "id": "db009015-c379-4f50-97ed-81dca9df28ac", - "metadata": { - "id": "db009015-c379-4f50-97ed-81dca9df28ac" - }, - "source": [ - "# **1. Create your Segger Dataset**\n", - "\n", - "In this step, we generate the dataset required for Segger's cell segmentation tasks.\n", - "\n", - "Segger relies on spatial transcriptomics data, combining staining **boundaries** (e.g., nuclei or membrane stainings) and **transcripts** from single-cell resolved imaging datasets. These nuclei and transcript nodes are represented in a graph, and the spatial proximity of transcripts to nuclei is used to establish edges between them.\n", - "\n", - "To use Segger with a Xenium dataset, you need the **`transcripts.parquet`** and **`nucleus_boundaries.parquet`** (or **`cell_boundaries.parquet`**, in case the Xenium samples comes with the segmentation kit) files. The **transcripts** file contains spatial coordinates and information for each transcript, while the **boundaries** file defines the polygon boundaries of the nuclei or cells. These files enable segger to map transcripts to their respective nuclei and perform cell segmentation based on spatial relationships. 
Segger can also be extended to other platforms by modifying the column names or formats in the input files to match its expected structure, making it adaptable for various spatial transcriptomics technologies. See (this)[https://github.com/EliHei2/segger_dev/tree/main/src/segger/data/parquet/_settings] for Xenium settings." - ] - }, - { - "cell_type": "markdown", - "id": "f488a0b7", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "598b4b16", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "id": "9d2b090b", - "metadata": { - "id": "9d2b090b" - }, - "source": [ - "### **1.1. Fast Dataset Creation with segger**\n", - "\n", - "Segger introduces a fast and efficient pipeline for processing spatial transcriptomics data. This method accelerates dataset creation, particularly for large datasets, by using **ND-tree-based spatial partitioning** and **parallel processing**. This results in a much faster preparation of the dataset, which is saved in PyTorch Geometric (PyG) format, similar to the previous method.\n", - "\n", - "**Note**: The previous dataset creation method will soon be deprecated in favor of this optimized pipeline.\n", - "\n", - "The pipeline requires the following inputs:\n", - "\n", - "- **base_dir**: The directory containing the raw dataset.\n", - "- **data_dir**: The directory where the processed dataset (tiles in PyG format) will be saved.\n", - "\n", - "The core improvements in this method come from the use of **ND-tree partitioning**, which splits the data efficiently into spatial regions, and **parallel processing**, which speeds up the handling of these regions across multiple CPU cores. 
For example, using this pipeline, the Xenium Human Pancreatic Dataset can be processed in just a few minutes when running with 16 workers.\n", - "\n", - "Below is an example of how to create a dataset using the faster Segger pipeline:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "vlDtoWZb24FJ", - "metadata": { - "id": "vlDtoWZb24FJ" - }, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e933ebf3", - "metadata": { - "id": "e933ebf3" - }, - "outputs": [], - "source": [ - "xenium_data_dir = Path('data_xenium')\n", - "segger_data_dir = Path('data_segger')\n", - "\n", - "sample = STSampleParquet(\n", - " base_dir=xenium_data_dir,\n", - " n_workers=4,\n", - " sample_type='xenium', # this could be 'xenium_v2' in case one uses the cell boundaries from the segmentation kit.\n", - " # weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available\n", - ")\n", - "\n", - "sample.save(\n", - " data_dir=segger_data_dir,\n", - " k_bd=3,\n", - " dist_bd=15.0,\n", - " k_tx=3,\n", - " dist_tx=5.0,\n", - " tile_width=120,\n", - " tile_height=120,\n", - " neg_sampling_ratio=5.0,\n", - " frac=1.0,\n", - " val_prob=0.1,\n", - " test_prob=0.2,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "6ab27f9a", - "metadata": { - "id": "6ab27f9a" - }, - "source": [ - "#### **Parameters**\n", - "Here is a complete list of parameters you can use to control the dataset creation process:\n", - "\n", - "- **--base_dir**: Directory containing the raw spatial transcriptomics dataset.\n", - "- **--data_dir**: Directory where the processed Segger dataset (in PyG format) will be saved.\n", - "- **--sample_type**: (Optional) Specifies the type of dataset (e.g., \"xenium\" or \"merscope\"). 
Defaults to None.\n", - "- **--scrnaseq_file**: Path to the scRNAseq file (default: None).\n", - "- **--celltype_column**: Column name for cell type annotations in the scRNAseq file (default: None).\n", - "- **--k_bd**: Number of nearest neighbors for boundary nodes (default: 3).\n", - "- **--dist_bd**: Maximum distance for boundary neighbors (default: 15.0).\n", - "- **--k_tx**: Number of nearest neighbors for transcript nodes (default: 3).\n", - "- **--dist_tx**: Maximum distance for transcript neighbors (default: 5.0).\n", - "- **--tile_size**: Specifies the size of the tile. If provided, it overrides both tile_width and tile_height.\n", - "- **--tile_width**: Width of the tiles in pixels (ignored if tile_size is provided).\n", - "- **--tile_height**: Height of the tiles in pixels (ignored if tile_size is provided).\n", - "- **--neg_sampling_ratio**: Ratio of negative samples (default: 5.0).\n", - "- **--frac**: Fraction of the dataset to process (default: 1.0).\n", - "- **--val_prob**: Proportion of data used for validation split (default: 0.1).\n", - "- **--test_prob**: Proportion of data used for testing split (default: 0.2).\n", - "- **--n_workers**: Number of workers for parallel processing (default: 1)." - ] - }, - { - "cell_type": "markdown", - "id": "70755046", - "metadata": {}, - "source": [ - "### **1.2. Using custom gene embeddings**\n", - "\n", - "In the default mode, segger initially tokenizes transcripts based on their gene type simply in a one-hot manner. However, one can use other genes embeddings (e.g., pre-trained embeddings). 
The following example shows how one can employ a cell-type-annotated scRNAseq reference of the same tissue type (not necessary same sample or experiment) to embed genes based on their abaundance in different cell types:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3bf18259", - "metadata": {}, - "outputs": [], - "source": [ - "from segger.data.utils import calculate_gene_celltype_abundance_embedding\n", - "scrnaseq_file = Path('my_scRNAseq_file.h5ad')\n", - "celltype_column = 'celltype_column'\n", - "gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(\n", - " sc.read(scrnaseq_file),\n", - " celltype_column\n", - ")\n", - "\n", - "sample = STSampleParquet(\n", - " base_dir=xenium_data_dir,\n", - " n_workers=4,\n", - " sample_type='xenium', # this could be 'xenium_v2' in case one uses the cell boundaries from the segmentation kit.\n", - " weights=gene_celltype_abundance_embedding, \n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "9962e4b8-4028-4683-9b75-d674fa6fb01d", - "metadata": { - "id": "9962e4b8-4028-4683-9b75-d674fa6fb01d" - }, - "source": [ - "# **2. Train your Segger Model**\n", - "\n", - "The Segger model training process begins after the dataset has been created. This model is a **heterogeneous graph neural network (GNN)** designed to segment single cells by leveraging both nuclei and transcript data.\n", - "\n", - "Segger uses graph attention layers to propagate information across nodes (nuclei and transcripts) and refine cell boundaries. The model architecture includes initial embedding layers, attention-based graph convolutions, and residual connections for stable learning.\n", - "\n", - "Segger leverages the **PyTorch Lightning** framework to streamline the training and evaluation of its graph neural network (GNN). PyTorch Lightning simplifies the training process by abstracting away much of the boilerplate code, allowing users to focus on model development and experimentation. 
It also supports multi-GPU training, mixed-precision, and efficient scaling, making it an ideal framework for training complex models like Segger." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4db89cb4-d0eb-426a-a71f-d127926fa412", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-12T00:49:07.236043Z", - "iopub.status.busy": "2024-09-12T00:49:07.235854Z", - "iopub.status.idle": "2024-09-12T00:49:08.351946Z", - "shell.execute_reply": "2024-09-12T00:49:08.351565Z", - "shell.execute_reply.started": "2024-09-12T00:49:07.236028Z" - }, - "id": "4db89cb4-d0eb-426a-a71f-d127926fa412" - }, - "outputs": [], - "source": [ - "# Base directory to store Pytorch Lightning models\n", - "models_dir = Path('models')\n", - "\n", - "# Initialize the Lightning model\n", - "metadata = ([\"tx\", \"bd\"], [(\"tx\", \"belongs\", \"bd\"), (\"tx\", \"neighbors\", \"tx\")])\n", - "ls = LitSegger(\n", - " num_tx_tokens=500,\n", - " init_emb=8, \n", - " hidden_channels=64,\n", - " out_channels=16,\n", - " heads=4,\n", - " num_mid_layers=1,\n", - " aggr='sum',\n", - " metadata=metadata,\n", - ")\n", - "\n", - "# Initialize the Lightning data module\n", - "dm = SeggerDataModule(\n", - " data_dir=segger_data_dir,\n", - " batch_size=2,\n", - " num_workers=2,\n", - ")\n", - "\n", - "dm.setup()\n", - "\n", - "\n", - "# if you wish to use more than 1 device for training you should run this:\n", - "batch = dm.train[0]\n", - "ls.forward(batch)\n", - "\n", - "# Initialize the Lightning trainer\n", - "trainer = Trainer(\n", - " accelerator='cuda',\n", - " strategy='auto',\n", - " precision='16-mixed',\n", - " devices=1, # set higher number if more gpus are available\n", - " max_epochs=100,\n", - " default_root_dir=models_dir,\n", - " logger=CSVLogger(models_dir),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "207864b8-7e52-4add-a4a2-e95a4debdc06", - "metadata": { - "id": "207864b8-7e52-4add-a4a2-e95a4debdc06", - "scrolled": 
true - }, - "outputs": [], - "source": [ - "# Fit model\n", - "trainer.fit(\n", - " model=ls,\n", - " datamodule=dm\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "e6214a79", - "metadata": {}, - "source": [ - "Key parameters for training:\n", - "- **`--data_dir`**: Directory containing the training data.\n", - "- **`--model_dir`**: Directory in which to store models.\n", - "- **`--epochs`**: Specifies the number of training epochs.\n", - "- **`--batch_size`**: Batch sizes for training and validation data.\n", - "- **`--learning_rate`**: The initial learning rate for the optimizer.\n", - "- **`--hidden_channels`**: Number of hidden channels in the GNN layers.\n", - "- **`--heads`**: Number of attention heads used in each graph convolutional layer.\n", - "- **`--init_emb`**: Sets the dimensionality of the initial embeddings applied to the input node features (e.g., transcripts). A higher embedding dimension may capture more feature complexity but also requires more computation.\n", - "- **`--out_channels`**: Specifies the number of output channels after the final graph attention layer, e.g. the final learned representations of the graph nodes.\n", - "\n", - "Additional Options for Training the Segger Model:\n", - "\n", - "- **`--aggr`**: This option controls the aggregation method used in the graph convolution layers.\n", - "- **`--accelerator`**: Controls the hardware used for training, such as `cuda` for GPU training. This enables Segger to leverage GPU resources for faster training, especially useful for large datasets.\n", - "- **`--strategy`**: Defines the distributed training strategy, with `auto` allowing PyTorch Lightning to automatically configure the best strategy based on the hardware setup.\n", - "- **`--precision`**: Enables mixed precision training (e.g., `16-mixed`), which can speed up training and reduce memory usage while maintaining accuracy." 
- ] - }, - { - "cell_type": "markdown", - "id": "9a7d20c6-ca16-4beb-b627-afb41e3fb491", - "metadata": { - "id": "9a7d20c6-ca16-4beb-b627-afb41e3fb491" - }, - "source": [ - "### *Troubleshooting #1*\n", - "\n", - "In the cell below, we are visualizing key metrics from the model training and validation process. The plot displays **training loss**, **validation loss**, **F1 validation score**, and **AUROC validation score** over training steps. We expect to see the loss curves decreasing over time, signaling the model's improvement, and the F1 and AUROC scores increasing, reflecting improved segmentation performance as the model learns.\n", - "\n", - "If training is not working effectively, you might observe the following in the plot displaying **training loss**, **validation loss**, **F1 score**, and **AUROC**:\n", - "\n", - "- **Training loss not decreasing**: If the training loss remains high or fluctuates without a consistent downward trend, this indicates that the model is not learning effectively from the training data.\n", - "- **Validation loss decreases, then increases**: If validation loss decreases initially but starts to increase while training loss continues to drop, this could be a sign of **overfitting**, where the model is performing well on the training data but not generalizing to the validation data.\n", - "- **F1 score and AUROC not improving**: If these metrics remain flat or show inconsistent improvement, the model may be struggling to correctly segment cells or classify transcripts, indicating an issue with learning performance.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43a9c1a4-3898-407d-ac0f-f98b13694593", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-11T22:06:58.182616Z", - "iopub.status.busy": "2024-09-11T22:06:58.182357Z", - "iopub.status.idle": "2024-09-11T22:07:01.063645Z", - "shell.execute_reply": "2024-09-11T22:07:01.063184Z", - "shell.execute_reply.started": 
"2024-09-11T22:06:58.182599Z" - }, - "id": "43a9c1a4-3898-407d-ac0f-f98b13694593", - "outputId": "70ba8e1b-7814-497a-c8b6-8aa7295cd4d9" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Text(0.5, 0, 'Step')" - ] - }, - "execution_count": 88, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Evaluate results\n", - "model_version = 0 # 'v_num' from training output above\n", - "model_path = models_dir / 'lightning_logs' / f'version_{model_version}'\n", - "metrics = pd.read_csv(model_path / 'metrics.csv', index_col=1)\n", - "\n", - "fig, ax = plt.subplots(1,1, figsize=(2,2))\n", - "\n", - "for col in metrics.columns.difference(['epoch']):\n", - " metric = metrics[col].dropna()\n", - " ax.plot(metric.index, metric.values, label=col)\n", - "\n", - "ax.legend(loc=(1, 0.33))\n", - "ax.set_ylim(0, 1)\n", - "ax.set_xlabel('Step')" - ] - }, - { - "cell_type": "markdown", - "id": "e73687e1-ee8f-46e9-8bd2-1ddc571ef94b", - "metadata": { - "id": "e73687e1-ee8f-46e9-8bd2-1ddc571ef94b" - }, - "source": [ - "# **3. Make Predictions**\n", - "\n", - "Once the Segger model is trained, it can be used to make predictions on seen (partially trained) data or be transfered to unseen data. This step involves using a trained checkpoint to predict cell boundaries and refine transcript-nuclei associations." 
- ] - }, - { - "cell_type": "markdown", - "id": "9807abf3", - "metadata": { - "id": "9807abf3" - }, - "source": [ - "#### **Requirements for the Faster Prediction Pipeline**\n", - "The pipeline requires the following inputs:\n", - "\n", - "- **segger_data_dir**: The directory containing the processed Segger dataset (in PyG format).\n", - "- **models_dir**: The directory containing the trained Segger model checkpoints.\n", - "- **benchmarks_dir**: The directory where the segmentation results will be saved.\n", - "- **transcripts_file**: Path to the file containing the transcript data for prediction.\n", - "\n", - "#### **Running the Faster Prediction Pipeline**\n", - "Below is an example of how to run the faster Segger prediction pipeline using the command line:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "PEOtAs-t9CiY", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "PEOtAs-t9CiY", - "outputId": "8b7a5375-9ebc-4bb4-9421-254410319120" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Starting segmentation for segger_embedding_1001...\n" - ] - } - ], - "source": [ - "dm = SeggerDataModule(\n", - " data_dir='data_segger',\n", - " batch_size=1,\n", - " num_workers=4,\n", - ")\n", - "\n", - "dm.setup()\n", - "\n", - "model_version = 0\n", - "model_path = Path('models') / \"lightning_logs\" / f\"version_{model_version}\"\n", - "model = load_model(model_path / \"checkpoints\")\n", - "\n", - "receptive_field = {'k_bd': 4, 'dist_bd': 12, 'k_tx': 15, 'dist_tx': 3}\n", - "\n", - "segment(\n", - " model,\n", - " dm,\n", - " save_dir='benchmarks',\n", - " seg_tag='segger_output',\n", - " transcript_file='data_xenium/transcripts.parquet',\n", - " receptive_field=receptive_field,\n", - " min_transcripts=5,\n", - " cell_id_col='segger_cell_id',\n", - " use_cc=False,\n", - " knn_method='cuda',\n", - " verbose=True,\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "id": 
"3aa5002c", - "metadata": {}, - "source": [] - }, - { - "cell_type": "markdown", - "id": "0a823035", - "metadata": { - "id": "0a823035" - }, - "source": [ - "#### **Parameters**\n", - "Here is a detailed explanation of each parameter used in the faster prediction pipeline:\n", - "\n", - "- **--segger_data_dir**: The directory containing the processed Segger dataset, saved as PyTorch Geometric data objects, that will be used for prediction.\n", - "- **--models_dir**: The directory containing the trained Segger model checkpoints. These checkpoints store the learned weights required for making predictions.\n", - "- **--benchmarks_dir**: The directory where the segmentation results will be saved.\n", - "- **--transcripts_file**: Path to the *transcripts.parquet* file.\n", - "- **--batch_size**: Specifies the batch size for processing during prediction. Larger batch sizes speed up inference but use more memory (default: 1).\n", - "- **--num_workers**: Number of workers to use for parallel data loading (default: 1).\n", - "- **--model_version**: Version of the trained model to load for predictions, based on the version number from the training logs (default: 0).\n", - "- **--save_tag**: A tag used to name and organize the segmentation results (default: segger_embedding).\n", - "- **--min_transcripts**: The minimum number of transcripts required for segmentation (default: 5).\n", - "- **--cell_id_col**: The name of the column that stores the cell IDs (default: segger_cell_id).\n", - "- **--use_cc**: Enables the use of connected components (CC) for grouping transcripts that are not associated with any nucleus (default: False).\n", - "- **--knn_method**: Method for KNN (K-Nearest Neighbors) computation. Only option is \"cuda\" for this pipeline (default: cuda).\n", - "- **--file_format**: The format for saving the output segmentation data. 
Only option is \"anndata\" for this pipeline (default: anndata).\n", - "- **--k_bd**: Number of nearest neighbors for boundary nodes during segmentation (default: 4).\n", - "- **--dist_bd**: Maximum distance for boundary nodes during segmentation (default: 12.0).\n", - "- **--k_tx**: Number of nearest neighbors for transcript nodes during segmentation (default: 5).\n", - "- **--dist_tx**: Maximum distance for transcript nodes during segmentation (default: 5.0)." - ] - }, - { - "cell_type": "markdown", - "id": "b0917be9-4e82-4ba5-869d-5a9203721699", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-11T23:06:23.977884Z", - "iopub.status.busy": "2024-09-11T23:06:23.977517Z" - }, - "id": "b0917be9-4e82-4ba5-869d-5a9203721699" - }, - "source": [ - "### *Troubleshooting #2*\n", - "\n", - "In the cell below, we are visualizing the distribution of **Segger similarity scores** using a histogram. The **Segger similarity score** reflects how closely transcripts are associated with their respective nuclei in the segmentation process. **Higher scores** indicate stronger associations between transcripts and their nuclei, suggesting more accurate cell boundaries. **Lower scores** might indicate weaker associations, which could highlight potential segmentation errors or challenging regions in the data. 
We expect to see a large number of the scores clustering toward higher values, which would indicate strong overall performance of the model in associating transcripts with nuclei.\n", - "\n", - "The following would indicate potential issues with the model's predictions:\n", - "\n", - "- **A very large portion of scores near zero**: If many scores are concentrated at the lower end of the scale, this suggests that the model is frequently failing to associate transcripts with their corresponding nuclei, indicating poor segmentation quality.\n", - "- **No clear peak in the distribution**: If the histogram is flat or shows a wide, spread-out distribution, this could indicate that the model is struggling to consistently assign similarity scores, which may be a sign that the training process did not optimize the model correctly.\n", - "\n", - "Both cases would suggest that the model requires further tuning, such as adjusting hyperparameters, data preprocessing, or the training procedure (see below)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a450d3ca-2876-4f48-be89-761147b17387", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-11T22:07:04.216273Z", - "iopub.status.busy": "2024-09-11T22:07:04.215965Z", - "iopub.status.idle": "2024-09-11T22:07:08.177601Z", - "shell.execute_reply": "2024-09-11T22:07:08.177158Z", - "shell.execute_reply.started": "2024-09-11T22:07:04.216257Z" - }, - "id": "a450d3ca-2876-4f48-be89-761147b17387", - "outputId": "0576e0b8-4823-4701-b661-dc6b513841f2" - }, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "fig, ax = plt.subplots(1,1, figsize=(2,2))\n", - "sns.histplot(\n", - " segmentation['score'],\n", - " bins=50,\n", - " ax=ax,\n", - ")\n", - "ax.set_ylabel('Count')\n", - "ax.set_xlabel('Segger Similarity Score')\n", - "ax.set_yscale('log')" - ] - }, - { - "cell_type": "markdown", - "id": "5492fb96-bf8e-49d5-b40e-7e6b3f871bbe", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-11T22:34:15.990223Z", - "iopub.status.busy": "2024-09-11T22:34:15.988880Z" - }, - "id": "5492fb96-bf8e-49d5-b40e-7e6b3f871bbe" - }, - "source": [ - "#### The Importance of the Receptive Field in Segger\n", - "\n", - "The **receptive field** is a critical parameter in Segger, as it directly influences how the model interprets the spatial relationships between **transcripts** and **nuclei**. In the context of spatial transcriptomics, the receptive field determines the size of the neighborhood that each node (representing transcripts or nuclei) can \"see\" during graph construction and model training. Segger is particularly sensitive to the size of the receptive field because it affects the model's ability to propagate information across the graph. If the receptive field is too small, the model may fail to capture sufficient context for correct cell boundary delineation. Conversely, a very large receptive field may introduce noise by linking unrelated or distant nodes, reducing segmentation accuracy.\n", - "\n", - "#### Parameters affecting the receptive field in Segger:\n", - "- **`--r`**: This parameter defines the radius used when connecting transcripts to nuclei. A larger `r` expands the receptive field, linking more distant nodes. Fine-tuning this parameter helps ensure that Segger captures the right level of spatial interaction in the dataset.\n", - "- **`--k_bd` and `--k_tx`**: These control the number of nearest neighbors (nuclei and transcripts, respectively) considered in the graph. 
By increasing these values, the receptive field is effectively broadened, allowing more nodes to contribute to the information propagation.\n", - "- **`--dist_bd` and `--dist_tx`**: These parameters specify the maximum distances used to connect nuclei (`dist_bd`) and transcripts (`dist_tx`) to their neighbors during graph construction. They directly affect the receptive field by defining the cut-off distance for forming edges in the graph. Larger distance values expand the receptive field, connecting nodes that are further apart spatially. Careful tuning of these values is necessary to ensure that Segger captures relevant spatial relationships without introducing noise." - ] - }, - { - "cell_type": "markdown", - "id": "7ece1ac0-0708-45e2-87fc-1b25782831f8", - "metadata": { - "id": "7ece1ac0-0708-45e2-87fc-1b25782831f8" - }, - "source": [ - "# **4. Tune Parameters**" - ] - }, - { - "cell_type": "markdown", - "id": "896b8288-5287-4d10-a206-e68c0e4731c6", - "metadata": { - "id": "896b8288-5287-4d10-a206-e68c0e4731c6" - }, - "source": [ - "### Evaluating Receptive Field Parameters with Grid Search\n", - "\n", - "To evaluate the impact of different receptive field parameters in Segger, we use a **grid search** approach. The parameters `k_bd`, `k_tx`, `dist_bd`, and `dist_tx` (which control the number of neighbors and distances for nuclei and transcripts) are explored through various configurations defined in `param_space`. Each combination of these parameters is passed to the `trainable` function, which creates the dataset, trains the model, and makes predictions based on the specified receptive field.\n", - "\n", - "For each parameter combination:\n", - "1. A dataset is created with the specified receptive field.\n", - "2. The Segger model is trained on this dataset.\n", - "3. Predictions are made, and segmentation results are evaluated using the custom `evaluate` function. 
This function computes metrics like the fraction of assigned transcripts and average cell sizes.\n", - "\n", - "The results from each configuration are saved, allowing us to compare how different receptive field settings impact the model’s performance. This process enables a thorough search of the parameter space, optimizing the model for accurate segmentation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b0c1a7a8-acb2-4aae-8ae4-8aa9a4196717", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-12T01:10:47.781418Z", - "iopub.status.busy": "2024-09-12T01:10:47.781067Z", - "iopub.status.idle": "2024-09-12T01:10:48.706615Z", - "shell.execute_reply": "2024-09-12T01:10:48.706194Z", - "shell.execute_reply.started": "2024-09-12T01:10:47.781401Z" - }, - "id": "b0c1a7a8-acb2-4aae-8ae4-8aa9a4196717" - }, - "outputs": [], - "source": [ - "import itertools\n", - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0bd0803c-e58d-4f43-9627-d2c1ab187d5e", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-12T01:16:31.976312Z", - "iopub.status.busy": "2024-09-12T01:16:31.975947Z", - "iopub.status.idle": "2024-09-12T01:16:33.168389Z", - "shell.execute_reply": "2024-09-12T01:16:33.167956Z", - "shell.execute_reply.started": "2024-09-12T01:16:31.976295Z" - }, - "id": "0bd0803c-e58d-4f43-9627-d2c1ab187d5e" - }, - "outputs": [], - "source": [ - "tuning_dir = Path('path/to/tutorial/tuning/')\n", - "sampling_rate = 0.125" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b879a0b5-150c-4240-99ec-81075855aa52", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-12T01:16:33.169525Z", - "iopub.status.busy": "2024-09-12T01:16:33.169189Z", - "iopub.status.idle": "2024-09-12T01:16:34.147222Z", - "shell.execute_reply": "2024-09-12T01:16:34.146804Z", - "shell.execute_reply.started": "2024-09-12T01:16:33.169508Z" - }, - "id": 
"b879a0b5-150c-4240-99ec-81075855aa52", - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "# Fixed function arguments used for each trial\n", - "transcripts_path = xenium_data_dir / 'transcripts.parquet'\n", - "\n", - "boundaries_path = xenium_data_dir / 'nucleus_boundaries.parquet'\n", - "\n", - "dataset_kwargs = dict(\n", - " x_size=80, y_size=80, d_x=80, d_y=80, margin_x=10, margin_y=10,\n", - " num_workers=4, sampling_rate=sampling_rate,\n", - ")\n", - "\n", - "model_kwargs = dict(\n", - " metadata=(['tx', 'bd'], [('tx', 'belongs', 'bd'), ('tx', 'neighbors', 'tx')]),\n", - " num_tx_tokens=500, init_emb=8, hidden_channels=32, out_channels=8,\n", - " heads=2, num_mid_layers=2, aggr='sum',\n", - ")\n", - "\n", - "trainer_kwargs = dict(\n", - " accelerator='cuda', strategy='auto', precision='16-mixed', devices=1,\n", - " max_epochs=100,\n", - ")\n", - "\n", - "predict_kwargs = dict(score_cut=0.2, use_cc=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fbd831c9-3a50-4e3b-97d3-3c152ae01188", - "metadata": { - "id": "fbd831c9-3a50-4e3b-97d3-3c152ae01188", - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "def trainable(config):\n", - "\n", - " receptive_field = {k: config[k] for k in ['k_bd', 'k_tx', 'dist_bd', 'dist_tx']}\n", - "\n", - " # Dataset creation\n", - " xs = XeniumSample(verbose=False)\n", - " xs.set_file_paths(transcripts_path, boundaries_path)\n", - " xs.set_metadata()\n", - " try:\n", - " xs.save_dataset_for_segger(\n", - " processed_dir=config['data_dir'],\n", - " receptive_field=receptive_field,\n", - " **dataset_kwargs,\n", - " )\n", - " except:\n", - " pass\n", - "\n", - " # Model training\n", - " ls = LitSegger(**model_kwargs)\n", - " dm = SeggerDataModule(\n", - " data_dir=config['data_dir'],\n", - " batch_size=2,\n", - " num_workers=dataset_kwargs['num_workers'],\n", - " )\n", - " trainer = Trainer(\n", - " default_root_dir=config['model_dir'],\n", - " 
logger=CSVLogger(config['model_dir']),\n", - " **trainer_kwargs,\n", - " )\n", - " trainer.fit(model=ls, datamodule=dm)\n", - "\n", - " segmentation = predict(\n", - " load_model(config['model_dir']/'lightning_logs/version_0/checkpoints'),\n", - " dm.train_dataloader(),\n", - " receptive_field=receptive_field,\n", - " **predict_kwargs,\n", - " )\n", - "\n", - " metrics = evaluate(segmentation)\n", - "\n", - "\n", - "def evaluate(segmentation: pd.DataFrame, score_cut: float) -> pd.Series:\n", - "\n", - " assigned = segmentation['score'] > score_cut\n", - " metrics = pd.Series(dtype=float)\n", - " metrics['frac_assigned'] = assigned.mean()\n", - " cell_sizes = segmentation.groupby(assigned)['segger_cell_id'].value_counts()\n", - " assigned_avg = 0 if True not in cell_sizes.index else cell_sizes[True].mean()\n", - " cc_avg = 0 if False not in cell_sizes.index else cell_sizes[False].mean()\n", - " metrics['cell_size_assigned'] = assigned_avg\n", - " metrics['cell_size_cc'] = cc_avg\n", - " return metrics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba2dcc9a-3a06-4b84-a487-59a768eed5d5", - "metadata": { - "execution": { - "iopub.execute_input": "2024-09-12T01:16:35.184598Z", - "iopub.status.busy": "2024-09-12T01:16:35.184180Z", - "iopub.status.idle": "2024-09-12T01:19:55.171470Z", - "shell.execute_reply": "2024-09-12T01:19:55.170810Z", - "shell.execute_reply.started": "2024-09-12T01:16:35.184582Z" - }, - "id": "ba2dcc9a-3a06-4b84-a487-59a768eed5d5", - "scrolled": true - }, - "outputs": [], - "source": [ - "param_space = {\n", - " \"k_bd\": [3, 5, 10],\n", - " \"dist_bd\": [5, 10, 15, 20],\n", - " \"k_tx\": [3, 5, 10],\n", - " \"dist_tx\": [3, 5, 10],\n", - "}\n", - "\n", - "metrics = []\n", - "\n", - "for params in itertools.product(*param_space.values()):\n", - "\n", - " config = dict(zip(param_space.keys(), params))\n", - "\n", - " # Setup directories\n", - " trial_dir = tuning_dir / '_'.join([f'{k}={v}' for k, v in config.items()])\n", - 
"\n", - " data_dir = trial_dir / 'segger_data'\n", - " data_dir.mkdir(exist_ok=True, parents=True)\n", - " config['data_dir'] = data_dir\n", - "\n", - " model_dir = trial_dir / 'models'\n", - " model_dir.mkdir(exist_ok=True, parents=True)\n", - " config['model_dir'] = model_dir\n", - "\n", - " segmentation = trainable(config)\n", - " trial = evaluate(segmentation, predict_kwargs['score_cut'])\n", - " trial = pd.concat([pd.Series(config), trial])\n", - " metrics.append(trial)\n", - "\n", - "metrics = pd.DataFrame(metrics)" - ] - }, - { - "cell_type": "markdown", - "id": "dcfa5570-ada2-4102-aae0-a3830d304c5f", - "metadata": { - "id": "dcfa5570-ada2-4102-aae0-a3830d304c5f" - }, - "source": [ - "### Interpreting Output Metrics\n", - "\n", - "The key output metrics include:\n", - "- **`frac_assigned`**: The fraction of transcripts that were successfully assigned to a cell. A higher value indicates that the model is doing a good job associating transcripts with nuclei, which is a strong indicator of successful segmentation.\n", - "- **`cell_size_assigned`**: The average size of cells that have assigned transcripts. This helps assess how well the model is predicting cell boundaries, with unusually large or small values indicating potential issues with segmentation accuracy.\n", - "- **`cell_size_cc`**: The average size of connected components that were not assigned to a cell (i.e., nucleus-less regions). 
Large values here may suggest that transcripts are being incorrectly grouped together in the absence of a nucleus, which could indicate problems with the receptive field parameters or the segmentation process.\n", - "\n", - "These metrics illuminate the effectiveness of the model by highlighting both the success in associating transcripts with cells and potential areas where the model may need further tuning.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1a89aed4-c53b-460f-8a6f-f690920b6829", - "metadata": { - "execution": { - "iopub.status.busy": "2024-09-12T01:19:55.171961Z", - "iopub.status.idle": "2024-09-12T01:19:55.172161Z", - "shell.execute_reply": "2024-09-12T01:19:55.172071Z", - "shell.execute_reply.started": "2024-09-12T01:19:55.172062Z" - }, - "id": "1a89aed4-c53b-460f-8a6f-f690920b6829" - }, - "outputs": [], - "source": [ - "metrics" - ] + "cells": [ + { + "cell_type": "markdown", + "id": "21ed4db6-5234-46b1-9f38-b5883ac88946", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-11T22:22:55.404267Z", + "iopub.status.busy": "2024-09-11T22:22:55.403876Z", + "iopub.status.idle": "2024-09-11T22:22:58.089917Z", + "shell.execute_reply": "2024-09-11T22:22:58.089303Z", + "shell.execute_reply.started": "2024-09-11T22:22:55.404248Z" + }, + "id": "21ed4db6-5234-46b1-9f38-b5883ac88946" + }, + "source": [ + "# **Introduction to Segger**\n", + "\n", + "\n", + "**Important note (Dec 2024):** As segger is currently undergoing constant development we highly recommend installing directly via github.\n", + "\n", + "\n", + "Segger is a cutting-edge cell segmentation model specifically designed for **single-molecule resolved spatial omics** datasets. 
It addresses the challenge of accurately segmenting individual cells in complex imaging datasets, leveraging a unique approach based on graph neural networks (GNNs).\n", + "\n", + "The core idea behind Segger is to model both **nuclei** and **transcripts** as graph nodes, with edges connecting them based on their spatial proximity. This allows the model to learn from the co-occurrence of nucleic and cytoplasmic molecules, resulting in more refined and accurate cell boundaries. By using spatial information and GNNs, Segger achieves state-of-the-art performance in segmenting single cells in datasets such as 10X Xenium and MERSCOPE, outperforming traditional methods like Baysor and Cellpose.\n", + "\n", + "Segger's workflow consists of:\n", + "1. **Dataset creation**: Converting raw transcriptomic data into a graph-based dataset.\n", + "2. **Training**: Training the Segger model on the graph to learn cell boundaries.\n", + "3. **Prediction**: Using the trained model to make predictions on new datasets.\n", + "\n", + "This tutorial will guide you through each step of the process, ensuring you can train and apply Segger for your own data." 
+ ] + }, + { + "cell_type": "markdown", + "id": "XEY6CTzK0648", + "metadata": { + "id": "XEY6CTzK0648" + }, + "source": [ + "Installing segger from the GitHub repository:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "TIQnPzfx08Zr", + "metadata": { + "id": "TIQnPzfx08Zr" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/EliHei2/segger_dev.git\n", + "%cd segger_dev\n", + "!pip install \".[rapids12]\" -q" + ] + }, + { + "cell_type": "markdown", + "id": "q3SNnImS09_N", + "metadata": { + "id": "q3SNnImS09_N" + }, + "source": [ + "Downloading the [Xenium Human Pancreatic Dataset](https://www.10xgenomics.com/products/xenium-human-pancreatic-dataset-explorer):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "Qjdt3f-U0_i9", + "metadata": { + "id": "Qjdt3f-U0_i9" + }, + "outputs": [], + "source": [ + "!mkdir data_xenium\n", + "%cd data_xenium\n", + "!wget https://cf.10xgenomics.com/samples/xenium/1.6.0/Xenium_V1_hPancreas_Cancer_Add_on_FFPE/Xenium_V1_hPancreas_Cancer_Add_on_FFPE_outs.zip\n", + "!unzip Xenium_V1_hPancreas_Cancer_Add_on_FFPE_outs.zip\n", + "%cd .." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "trM8h-Ek16sJ", + "metadata": { + "id": "trM8h-Ek16sJ" + }, + "outputs": [], + "source": [ + "from segger.data.parquet.sample import STSampleParquet\n", + "from segger.training.segger_data_module import SeggerDataModule\n", + "from segger.training.train import LitSegger\n", + "from segger.prediction.predict_parquet import segment, load_model\n", + "from lightning.pytorch.loggers import CSVLogger\n", + "from pytorch_lightning import Trainer\n", + "from pathlib import Path\n", + "import pandas as pd\n", + "from matplotlib import pyplot as plt\n", + "import seaborn as sns\n", + "import scanpy as sc" + ] + }, + { + "cell_type": "markdown", + "id": "db009015-c379-4f50-97ed-81dca9df28ac", + "metadata": { + "id": "db009015-c379-4f50-97ed-81dca9df28ac" + }, + "source": [ + "# **1. 
Create your Segger Dataset**\n", + "\n", + "In this step, we generate the dataset required for Segger's cell segmentation tasks.\n", + "\n", + "Segger relies on spatial transcriptomics data, combining staining **boundaries** (e.g., nuclei or membrane stainings) and **transcripts** from single-cell resolved imaging datasets. These nuclei and transcript nodes are represented in a graph, and the spatial proximity of transcripts to nuclei is used to establish edges between them.\n", + "\n", + "To use Segger with a Xenium dataset, you need the **`transcripts.parquet`** and **`nucleus_boundaries.parquet`** (or **`cell_boundaries.parquet`**, in case the Xenium sample comes with the segmentation kit) files. The **transcripts** file contains spatial coordinates and information for each transcript, while the **boundaries** file defines the polygon boundaries of the nuclei or cells. These files enable segger to map transcripts to their respective nuclei and perform cell segmentation based on spatial relationships. Segger can also be extended to other platforms by modifying the column names or formats in the input files to match its expected structure, making it adaptable for various spatial transcriptomics technologies. See [this](https://github.com/EliHei2/segger_dev/tree/main/src/segger/data/parquet/_settings) for Xenium settings." + ] + }, + { + "cell_type": "markdown", + "id": "9d2b090b", + "metadata": { + "id": "9d2b090b" + }, + "source": [ + "### **1.1. Fast Dataset Creation with segger**\n", + "\n", + "Segger introduces a fast and efficient pipeline for processing spatial transcriptomics data. This method accelerates dataset creation, particularly for large datasets, by using **ND-tree-based spatial partitioning** and **parallel processing**. 
This results in a much faster preparation of the dataset, which is saved in PyTorch Geometric (PyG) format, similar to the previous method.\n", + "\n", + "**Note**: The previous dataset creation method will soon be deprecated in favor of this optimized pipeline.\n", + "\n", + "The pipeline requires the following inputs:\n", + "\n", + "- **base_dir**: The directory containing the raw dataset.\n", + "- **data_dir**: The directory where the processed dataset (tiles in PyG format) will be saved.\n", + "\n", + "The core improvements in this method come from the use of **ND-tree partitioning**, which splits the data efficiently into spatial regions, and **parallel processing**, which speeds up the handling of these regions across multiple CPU cores. For example, using this pipeline, the Xenium Human Pancreatic Dataset can be processed in just a few minutes when running with 16 workers.\n", + "\n", + "Below is an example of how to create a dataset using the faster Segger pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e933ebf3", + "metadata": { + "id": "e933ebf3" + }, + "outputs": [], + "source": [ + "xenium_data_dir = Path('data_xenium')\n", + "segger_data_dir = Path('data_segger')\n", + "\n", + "sample = STSampleParquet(\n", + " base_dir=xenium_data_dir,\n", + " n_workers=4,\n", + " sample_type='xenium', # this could be 'xenium_v2' in case one uses the cell boundaries from the segmentation kit.\n", + " # weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available\n", + ")\n", + "\n", + "sample.save(\n", + " data_dir=segger_data_dir,\n", + " k_bd=3,\n", + " dist_bd=15.0,\n", + " k_tx=3,\n", + " dist_tx=5.0,\n", + " tile_width=120,\n", + " tile_height=120,\n", + " neg_sampling_ratio=5.0,\n", + " frac=1.0,\n", + " val_prob=0.1,\n", + " test_prob=0.2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6ab27f9a", + "metadata": { + "id": "6ab27f9a" + }, + "source": [ + "#### **Parameters**\n", + 
"Here is a complete list of parameters you can use to control the dataset creation process:\n", + "\n", + "- **--base_dir**: Directory containing the raw spatial transcriptomics dataset.\n", + "- **--data_dir**: Directory where the processed Segger dataset (in PyG format) will be saved.\n", + "- **--sample_type**: (Optional) Specifies the type of dataset (e.g., \"xenium\" or \"merscope\"). Defaults to None.\n", + "- **--scrnaseq_file**: Path to the scRNAseq file (default: None).\n", + "- **--celltype_column**: Column name for cell type annotations in the scRNAseq file (default: None).\n", + "- **--k_bd**: Number of nearest neighbors for boundary nodes (default: 3).\n", + "- **--dist_bd**: Maximum distance for boundary neighbors (default: 15.0).\n", + "- **--k_tx**: Number of nearest neighbors for transcript nodes (default: 3).\n", + "- **--dist_tx**: Maximum distance for transcript neighbors (default: 5.0).\n", + "- **--tile_size**: Specifies the size of the tile. If provided, it overrides both tile_width and tile_height.\n", + "- **--tile_width**: Width of the tiles in pixels (ignored if tile_size is provided).\n", + "- **--tile_height**: Height of the tiles in pixels (ignored if tile_size is provided).\n", + "- **--neg_sampling_ratio**: Ratio of negative samples (default: 5.0).\n", + "- **--frac**: Fraction of the dataset to process (default: 1.0).\n", + "- **--val_prob**: Proportion of data used for validation split (default: 0.1).\n", + "- **--test_prob**: Proportion of data used for testing split (default: 0.2).\n", + "- **--n_workers**: Number of workers for parallel processing (default: 1)." + ] + }, + { + "cell_type": "markdown", + "id": "70755046", + "metadata": {}, + "source": [ + "### **1.2. Using custom gene embeddings**\n", + "\n", + "In the default mode, segger initially tokenizes transcripts based on their gene type simply in a one-hot manner. However, one can use other genes embeddings (e.g., pre-trained embeddings). 
The following example shows how one can employ a cell-type-annotated scRNAseq reference of the same tissue type (not necessarily the same sample or experiment) to embed genes based on their abundance in different cell types:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bf18259", + "metadata": {}, + "outputs": [], + "source": [ + "from segger.data.utils import calculate_gene_celltype_abundance_embedding\n", + "scrnaseq_file = Path('my_scRNAseq_file.h5ad')\n", + "celltype_column = 'celltype_column'\n", + "gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(\n", + " sc.read(scrnaseq_file),\n", + " celltype_column\n", + ")\n", + "\n", + "sample = STSampleParquet(\n", + " base_dir=xenium_data_dir,\n", + " n_workers=4,\n", + " sample_type='xenium', # this could be 'xenium_v2' in case one uses the cell boundaries from the segmentation kit.\n", + " weights=gene_celltype_abundance_embedding, \n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9962e4b8-4028-4683-9b75-d674fa6fb01d", + "metadata": { + "id": "9962e4b8-4028-4683-9b75-d674fa6fb01d" + }, + "source": [ + "# **2. Train your Segger Model**\n", + "\n", + "The Segger model training process begins after the dataset has been created. This model is a **heterogeneous graph neural network (GNN)** designed to segment single cells by leveraging both nuclei and transcript data.\n", + "\n", + "Segger uses graph attention layers to propagate information across nodes (nuclei and transcripts) and refine cell boundaries. The model architecture includes initial embedding layers, attention-based graph convolutions, and residual connections for stable learning.\n", + "\n", + "Segger leverages the **PyTorch Lightning** framework to streamline the training and evaluation of its graph neural network (GNN). PyTorch Lightning simplifies the training process by abstracting away much of the boilerplate code, allowing users to focus on model development and experimentation. 
It also supports multi-GPU training, mixed-precision, and efficient scaling, making it an ideal framework for training complex models like Segger." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4db89cb4-d0eb-426a-a71f-d127926fa412", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-12T00:49:07.236043Z", + "iopub.status.busy": "2024-09-12T00:49:07.235854Z", + "iopub.status.idle": "2024-09-12T00:49:08.351946Z", + "shell.execute_reply": "2024-09-12T00:49:08.351565Z", + "shell.execute_reply.started": "2024-09-12T00:49:07.236028Z" + }, + "id": "4db89cb4-d0eb-426a-a71f-d127926fa412" + }, + "outputs": [], + "source": [ + "# Base directory to store Pytorch Lightning models\n", + "models_dir = Path('models')\n", + "\n", + "# Initialize the Lightning data module\n", + "dm = SeggerDataModule(\n", + " data_dir=segger_data_dir,\n", + " batch_size=2,\n", + " num_workers=2,\n", + ")\n", + "\n", + "dm.setup()\n", + "\n", + "is_token_based = True\n", + "num_tx_tokens = 500\n", + "\n", + "# If you use custom gene embeddings, use the following two lines instead:\n", + "# is_token_based = False\n", + "# num_tx_tokens = dm.train[0].x_dict[\"tx\"].shape[1] # Set the number of tokens to the number of genes\n", + "\n", + "\n", + "num_bd_features = dm.train[0].x_dict[\"bd\"].shape[1]\n", + "\n", + "# Initialize the Lightning model\n", + "ls = LitSegger(\n", + " is_token_based = is_token_based,\n", + " num_node_features = {\"tx\": num_tx_tokens, \"bd\": num_bd_features},\n", + " init_emb=8, \n", + " hidden_channels=64,\n", + " out_channels=16,\n", + " heads=4,\n", + " num_mid_layers=1,\n", + " aggr='sum',\n", + ")\n", + "\n", + "# Initialize the Lightning trainer\n", + "trainer = Trainer(\n", + " accelerator='cuda',\n", + " strategy='auto',\n", + " precision='16-mixed',\n", + " devices=1, # set higher number if more gpus are available\n", + " max_epochs=100,\n", + " default_root_dir=models_dir,\n", + " logger=CSVLogger(models_dir),\n", + ")" + ] + 
}, + { + "cell_type": "code", + "execution_count": null, + "id": "207864b8-7e52-4add-a4a2-e95a4debdc06", + "metadata": { + "id": "207864b8-7e52-4add-a4a2-e95a4debdc06", + "scrolled": true + }, + "outputs": [], + "source": [ + "# Fit model\n", + "trainer.fit(\n", + " model=ls,\n", + " datamodule=dm\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e6214a79", + "metadata": {}, + "source": [ + "Key parameters for training:\n", + "- **`--data_dir`**: Directory containing the training data.\n", + "- **`--model_dir`**: Directory in which to store models.\n", + "- **`--epochs`**: Specifies the number of training epochs.\n", + "- **`--batch_size`**: Batch sizes for training and validation data.\n", + "- **`--learning_rate`**: The initial learning rate for the optimizer.\n", + "- **`--hidden_channels`**: Number of hidden channels in the GNN layers.\n", + "- **`--heads`**: Number of attention heads used in each graph convolutional layer.\n", + "- **`--init_emb`**: Sets the dimensionality of the initial embeddings applied to the input node features (e.g., transcripts). A higher embedding dimension may capture more feature complexity but also requires more computation.\n", + "- **`--out_channels`**: Specifies the number of output channels after the final graph attention layer, e.g. the final learned representations of the graph nodes.\n", + "\n", + "Additional Options for Training the Segger Model:\n", + "\n", + "- **`--aggr`**: This option controls the aggregation method used in the graph convolution layers.\n", + "- **`--accelerator`**: Controls the hardware used for training, such as `cuda` for GPU training. 
This enables Segger to leverage GPU resources for faster training, especially useful for large datasets.\n", + "- **`--strategy`**: Defines the distributed training strategy, with `auto` allowing PyTorch Lightning to automatically configure the best strategy based on the hardware setup.\n", + "- **`--precision`**: Enables mixed precision training (e.g., `16-mixed`), which can speed up training and reduce memory usage while maintaining accuracy." + ] + }, + { + "cell_type": "markdown", + "id": "9a7d20c6-ca16-4beb-b627-afb41e3fb491", + "metadata": { + "id": "9a7d20c6-ca16-4beb-b627-afb41e3fb491" + }, + "source": [ + "### *Troubleshooting #1*\n", + "\n", + "In the cell below, we are visualizing key metrics from the model training and validation process. The plot displays **training loss**, **validation loss**, **F1 validation score**, and **AUROC validation score** over training steps. We expect to see the loss curves decreasing over time, signaling the model's improvement, and the F1 and AUROC scores increasing, reflecting improved segmentation performance as the model learns.\n", + "\n", + "If training is not working effectively, you might observe the following in the plot displaying **training loss**, **validation loss**, **F1 score**, and **AUROC**:\n", + "\n", + "- **Training loss not decreasing**: If the training loss remains high or fluctuates without a consistent downward trend, this indicates that the model is not learning effectively from the training data.\n", + "- **Validation loss decreases, then increases**: If validation loss decreases initially but starts to increase while training loss continues to drop, this could be a sign of **overfitting**, where the model is performing well on the training data but not generalizing to the validation data.\n", + "- **F1 score and AUROC not improving**: If these metrics remain flat or show inconsistent improvement, the model may be struggling to correctly segment cells or classify transcripts, indicating an issue 
with learning performance.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43a9c1a4-3898-407d-ac0f-f98b13694593", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-11T22:06:58.182616Z", + "iopub.status.busy": "2024-09-11T22:06:58.182357Z", + "iopub.status.idle": "2024-09-11T22:07:01.063645Z", + "shell.execute_reply": "2024-09-11T22:07:01.063184Z", + "shell.execute_reply.started": "2024-09-11T22:06:58.182599Z" + }, + "id": "43a9c1a4-3898-407d-ac0f-f98b13694593", + "outputId": "70ba8e1b-7814-497a-c8b6-8aa7295cd4d9" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 0, 'Step')" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" } - ], - "metadata": { - "accelerator": "GPU", + ], + "source": [ + "# Evaluate results\n", + "model_version = 0 # 'v_num' from training output above\n", + "model_path = models_dir / 'lightning_logs' / f'version_{model_version}'\n", + "metrics = pd.read_csv(model_path / 'metrics.csv', index_col=1)\n", + "\n", + "fig, ax = plt.subplots(1,1, figsize=(2,2))\n", + "\n", + "for col in metrics.columns.difference(['epoch']):\n", + " metric = metrics[col].dropna()\n", + " ax.plot(metric.index, metric.values, label=col)\n", + "\n", + "ax.legend(loc=(1, 0.33))\n", + "ax.set_ylim(0, 1)\n", + "ax.set_xlabel('Step')" + ] + }, + { + "cell_type": "markdown", + "id": "e73687e1-ee8f-46e9-8bd2-1ddc571ef94b", + "metadata": { + "id": "e73687e1-ee8f-46e9-8bd2-1ddc571ef94b" + }, + "source": [ + "# **3. Make Predictions**\n", + "\n", + "Once the Segger model is trained, it can be used to make predictions on seen (partially trained) data or be transfered to unseen data. This step involves using a trained checkpoint to predict cell boundaries and refine transcript-nuclei associations." 
+ ] + }, + { + "cell_type": "markdown", + "id": "9807abf3", + "metadata": { + "id": "9807abf3" + }, + "source": [ + "#### **Requirements for the Faster Prediction Pipeline**\n", + "The pipeline requires the following inputs:\n", + "\n", + "- **segger_data_dir**: The directory containing the processed Segger dataset (in PyG format).\n", + "- **models_dir**: The directory containing the trained Segger model checkpoints.\n", + "- **benchmarks_dir**: The directory where the segmentation results will be saved.\n", + "- **transcripts_file**: Path to the file containing the transcript data for prediction.\n", + "\n", + "#### **Running the Faster Prediction Pipeline**\n", + "Below is an example of how to run the faster Segger prediction pipeline from Python:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "PEOtAs-t9CiY", + "metadata": { "colab": { - "gpuType": "T4", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" + "base_uri": "https://localhost:8080/" + }, + "id": "PEOtAs-t9CiY", + "outputId": "8b7a5375-9ebc-4bb4-9421-254410319120" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting segmentation for segger_embedding_1001...\n" + ] } + ], + "source": [ + "dm = SeggerDataModule(\n", + " data_dir='data_segger',\n", + " batch_size=1,\n", + " num_workers=4,\n", + ")\n", + "\n", + "dm.setup()\n", + "\n", + "model_version = 0\n", + "model_path = Path('models') / \"lightning_logs\" / f\"version_{model_version}\"\n", + "model = load_model(model_path / \"checkpoints\")\n", + "\n", + "receptive_field = {'k_bd': 4, 'dist_bd': 12, 'k_tx': 15, 'dist_tx': 3}\n", + "\n", + "segment(\n", + " model,\n", + " dm,\n", + " 
save_dir='benchmarks',\n", + " seg_tag='segger_output',\n", + " transcript_file='data_xenium/transcripts.parquet',\n", + " receptive_field=receptive_field,\n", + " min_transcripts=5,\n", + " cell_id_col='segger_cell_id',\n", + " use_cc=False,\n", + " knn_method='cuda',\n", + " verbose=True,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "3aa5002c", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "id": "0a823035", + "metadata": { + "id": "0a823035" + }, + "source": [ + "#### **Parameters**\n", + "Here is a detailed explanation of each parameter used in the faster prediction pipeline:\n", + "\n", + "- **--segger_data_dir**: The directory containing the processed Segger dataset, saved as PyTorch Geometric data objects, that will be used for prediction.\n", + "- **--models_dir**: The directory containing the trained Segger model checkpoints. These checkpoints store the learned weights required for making predictions.\n", + "- **--benchmarks_dir**: The directory where the segmentation results will be saved.\n", + "- **--transcripts_file**: Path to the *transcripts.parquet* file.\n", + "- **--batch_size**: Specifies the batch size for processing during prediction. 
Larger batch sizes speed up inference but use more memory (default: 1).\n", + "- **--num_workers**: Number of workers to use for parallel data loading (default: 1).\n", + "- **--model_version**: Version of the trained model to load for predictions, based on the version number from the training logs (default: 0).\n", + "- **--save_tag**: A tag used to name and organize the segmentation results (default: segger_embedding).\n", + "- **--min_transcripts**: The minimum number of transcripts required for segmentation (default: 5).\n", + "- **--cell_id_col**: The name of the column that stores the cell IDs (default: segger_cell_id).\n", + "- **--use_cc**: Enables the use of connected components (CC) for grouping transcripts that are not associated with any nucleus (default: False).\n", + "- **--knn_method**: Method for KNN (K-Nearest Neighbors) computation. Only option is \"cuda\" for this pipeline (default: cuda).\n", + "- **--file_format**: The format for saving the output segmentation data. Only option is \"anndata\" for this pipeline (default: anndata).\n", + "- **--k_bd**: Number of nearest neighbors for boundary nodes during segmentation (default: 4).\n", + "- **--dist_bd**: Maximum distance for boundary nodes during segmentation (default: 12.0).\n", + "- **--k_tx**: Number of nearest neighbors for transcript nodes during segmentation (default: 5).\n", + "- **--dist_tx**: Maximum distance for transcript nodes during segmentation (default: 5.0)." + ] + }, + { + "cell_type": "markdown", + "id": "b0917be9-4e82-4ba5-869d-5a9203721699", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-11T23:06:23.977884Z", + "iopub.status.busy": "2024-09-11T23:06:23.977517Z" + }, + "id": "b0917be9-4e82-4ba5-869d-5a9203721699" + }, + "source": [ + "### *Troubleshooting #2*\n", + "\n", + "In the cell below, we are visualizing the distribution of **Segger similarity scores** using a histogram. 
The **Segger similarity score** reflects how closely transcripts are associated with their respective nuclei in the segmentation process. **Higher scores** indicate stronger associations between transcripts and their nuclei, suggesting more accurate cell boundaries. **Lower scores** might indicate weaker associations, which could highlight potential segmentation errors or challenging regions in the data. We expect to see a large number of the scores clustering toward higher values, which would indicate strong overall performance of the model in associating transcripts with nuclei.\n", + "\n", + "The following would indicate potential issues with the model's predictions:\n", + "\n", + "- **A very large portion of scores near zero**: If many scores are concentrated at the lower end of the scale, this suggests that the model is frequently failing to associate transcripts with their corresponding nuclei, indicating poor segmentation quality.\n", + "- **No clear peak in the distribution**: If the histogram is flat or shows a wide, spread-out distribution, this could indicate that the model is struggling to consistently assign similarity scores, which may be a sign that the training process did not optimize the model correctly.\n", + "\n", + "Both cases would suggest that the model requires further tuning, such as adjusting hyperparameters, data preprocessing, or the training procedure (see below).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a450d3ca-2876-4f48-be89-761147b17387", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-11T22:07:04.216273Z", + "iopub.status.busy": "2024-09-11T22:07:04.215965Z", + "iopub.status.idle": "2024-09-11T22:07:08.177601Z", + "shell.execute_reply": "2024-09-11T22:07:08.177158Z", + "shell.execute_reply.started": "2024-09-11T22:07:04.216257Z" + }, + "id": "a450d3ca-2876-4f48-be89-761147b17387", + "outputId": "0576e0b8-4823-4701-b661-dc6b513841f2" + }, + "outputs": [ + { + "data": { + 
"image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(1,1, figsize=(2,2))\n", + "sns.histplot(\n", + " segmentation['score'],\n", + " bins=50,\n", + " ax=ax,\n", + ")\n", + "ax.set_ylabel('Count')\n", + "ax.set_xlabel('Segger Similarity Score')\n", + "ax.set_yscale('log')" + ] + }, + { + "cell_type": "markdown", + "id": "5492fb96-bf8e-49d5-b40e-7e6b3f871bbe", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-11T22:34:15.990223Z", + "iopub.status.busy": "2024-09-11T22:34:15.988880Z" + }, + "id": "5492fb96-bf8e-49d5-b40e-7e6b3f871bbe" + }, + "source": [ + "#### The Importance of the Receptive Field in Segger\n", + "\n", + "The **receptive field** is a critical parameter in Segger, as it directly influences how the model interprets the spatial relationships between **transcripts** and **nuclei**. In the context of spatial transcriptomics, the receptive field determines the size of the neighborhood that each node (representing transcripts or nuclei) can \"see\" during graph construction and model training. Segger is particularly sensitive to the size of the receptive field because it affects the model's ability to propagate information across the graph. If the receptive field is too small, the model may fail to capture sufficient context for correct cell boundary delineation. Conversely, a very large receptive field may introduce noise by linking unrelated or distant nodes, reducing segmentation accuracy.\n", + "\n", + "#### Parameters affecting the receptive field in Segger:\n", + "- **`--r`**: This parameter defines the radius used when connecting transcripts to nuclei. A larger `r` expands the receptive field, linking more distant nodes. Fine-tuning this parameter helps ensure that Segger captures the right level of spatial interaction in the dataset.\n", + "- **`--k_bd` and `--k_tx`**: These control the number of nearest neighbors (nuclei and transcripts, respectively) considered in the graph. 
By increasing these values, the receptive field is effectively broadened, allowing more nodes to contribute to the information propagation.\n", + "- **`--dist_bd` and `--dist_tx`**: These parameters specify the maximum distances used to connect nuclei (`dist_bd`) and transcripts (`dist_tx`) to their neighbors during graph construction. They directly affect the receptive field by defining the cut-off distance for forming edges in the graph. Larger distance values expand the receptive field, connecting nodes that are further apart spatially. Careful tuning of these values is necessary to ensure that Segger captures relevant spatial relationships without introducing noise." + ] + }, + { + "cell_type": "markdown", + "id": "7ece1ac0-0708-45e2-87fc-1b25782831f8", + "metadata": { + "id": "7ece1ac0-0708-45e2-87fc-1b25782831f8" + }, + "source": [ + "# **4. Tune Parameters**" + ] + }, + { + "cell_type": "markdown", + "id": "896b8288-5287-4d10-a206-e68c0e4731c6", + "metadata": { + "id": "896b8288-5287-4d10-a206-e68c0e4731c6" + }, + "source": [ + "### Evaluating Receptive Field Parameters with Grid Search\n", + "\n", + "To evaluate the impact of different receptive field parameters in Segger, we use a **grid search** approach. The parameters `k_bd`, `k_tx`, `dist_bd`, and `dist_tx` (which control the number of neighbors and distances for nuclei and transcripts) are explored through various configurations defined in `param_space`. Each combination of these parameters is passed to the `trainable` function, which creates the dataset, trains the model, and makes predictions based on the specified receptive field.\n", + "\n", + "For each parameter combination:\n", + "1. A dataset is created with the specified receptive field.\n", + "2. The Segger model is trained on this dataset.\n", + "3. Predictions are made, and segmentation results are evaluated using the custom `evaluate` function. 
This function computes metrics like the fraction of assigned transcripts and average cell sizes.\n", + "\n", + "The results from each configuration are saved, allowing us to compare how different receptive field settings impact the model’s performance. This process enables a thorough search of the parameter space, optimizing the model for accurate segmentation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0c1a7a8-acb2-4aae-8ae4-8aa9a4196717", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-12T01:10:47.781418Z", + "iopub.status.busy": "2024-09-12T01:10:47.781067Z", + "iopub.status.idle": "2024-09-12T01:10:48.706615Z", + "shell.execute_reply": "2024-09-12T01:10:48.706194Z", + "shell.execute_reply.started": "2024-09-12T01:10:47.781401Z" + }, + "id": "b0c1a7a8-acb2-4aae-8ae4-8aa9a4196717" + }, + "outputs": [], + "source": [ + "import itertools\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0bd0803c-e58d-4f43-9627-d2c1ab187d5e", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-12T01:16:31.976312Z", + "iopub.status.busy": "2024-09-12T01:16:31.975947Z", + "iopub.status.idle": "2024-09-12T01:16:33.168389Z", + "shell.execute_reply": "2024-09-12T01:16:33.167956Z", + "shell.execute_reply.started": "2024-09-12T01:16:31.976295Z" + }, + "id": "0bd0803c-e58d-4f43-9627-d2c1ab187d5e" + }, + "outputs": [], + "source": [ + "tuning_dir = Path('path/to/tutorial/tuning/')\n", + "sampling_rate = 0.125" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b879a0b5-150c-4240-99ec-81075855aa52", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-12T01:16:33.169525Z", + "iopub.status.busy": "2024-09-12T01:16:33.169189Z", + "iopub.status.idle": "2024-09-12T01:16:34.147222Z", + "shell.execute_reply": "2024-09-12T01:16:34.146804Z", + "shell.execute_reply.started": "2024-09-12T01:16:33.169508Z" + }, + "id": 
"b879a0b5-150c-4240-99ec-81075855aa52", + "jupyter": { + "source_hidden": true + } + }, + "outputs": [], + "source": [ + "# Fixed function arguments used for each trial\n", + "transcripts_path = xenium_data_dir / 'transcripts.parquet'\n", + "\n", + "boundaries_path = xenium_data_dir / 'nucleus_boundaries.parquet'\n", + "\n", + "dataset_kwargs = dict(\n", + " x_size=80, y_size=80, d_x=80, d_y=80, margin_x=10, margin_y=10,\n", + " num_workers=4, sampling_rate=sampling_rate,\n", + ")\n", + "\n", + "model_kwargs = dict(\n", + " metadata=(['tx', 'bd'], [('tx', 'belongs', 'bd'), ('tx', 'neighbors', 'tx')]),\n", + " num_tx_tokens=500, init_emb=8, hidden_channels=32, out_channels=8,\n", + " heads=2, num_mid_layers=2, aggr='sum',\n", + ")\n", + "\n", + "trainer_kwargs = dict(\n", + " accelerator='cuda', strategy='auto', precision='16-mixed', devices=1,\n", + " max_epochs=100,\n", + ")\n", + "\n", + "predict_kwargs = dict(score_cut=0.2, use_cc=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbd831c9-3a50-4e3b-97d3-3c152ae01188", + "metadata": { + "id": "fbd831c9-3a50-4e3b-97d3-3c152ae01188", + "jupyter": { + "source_hidden": true + } + }, + "outputs": [], + "source": [ + "def trainable(config):\n", + "\n", + " receptive_field = {k: config[k] for k in ['k_bd', 'k_tx', 'dist_bd', 'dist_tx']}\n", + "\n", + " # Dataset creation\n", + " xs = XeniumSample(verbose=False)\n", + " xs.set_file_paths(transcripts_path, boundaries_path)\n", + " xs.set_metadata()\n", + " try:\n", + " xs.save_dataset_for_segger(\n", + " processed_dir=config['data_dir'],\n", + " receptive_field=receptive_field,\n", + " **dataset_kwargs,\n", + " )\n", + " except:\n", + " pass\n", + "\n", + " # Model training\n", + " ls = LitSegger(**model_kwargs)\n", + " dm = SeggerDataModule(\n", + " data_dir=config['data_dir'],\n", + " batch_size=2,\n", + " num_workers=dataset_kwargs['num_workers'],\n", + " )\n", + " trainer = Trainer(\n", + " default_root_dir=config['model_dir'],\n", + " 
logger=CSVLogger(config['model_dir']),\n", + " **trainer_kwargs,\n", + " )\n", + " trainer.fit(model=ls, datamodule=dm)\n", + "\n", + " segmentation = predict(\n", + " load_model(config['model_dir']/'lightning_logs/version_0/checkpoints'),\n", + " dm.train_dataloader(),\n", + " receptive_field=receptive_field,\n", + " **predict_kwargs,\n", + " )\n", + "\n", + " metrics = evaluate(segmentation)\n", + "\n", + "\n", + "def evaluate(segmentation: pd.DataFrame, score_cut: float) -> pd.Series:\n", + "\n", + " assigned = segmentation['score'] > score_cut\n", + " metrics = pd.Series(dtype=float)\n", + " metrics['frac_assigned'] = assigned.mean()\n", + " cell_sizes = segmentation.groupby(assigned)['segger_cell_id'].value_counts()\n", + " assigned_avg = 0 if True not in cell_sizes.index else cell_sizes[True].mean()\n", + " cc_avg = 0 if False not in cell_sizes.index else cell_sizes[False].mean()\n", + " metrics['cell_size_assigned'] = assigned_avg\n", + " metrics['cell_size_cc'] = cc_avg\n", + " return metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba2dcc9a-3a06-4b84-a487-59a768eed5d5", + "metadata": { + "execution": { + "iopub.execute_input": "2024-09-12T01:16:35.184598Z", + "iopub.status.busy": "2024-09-12T01:16:35.184180Z", + "iopub.status.idle": "2024-09-12T01:19:55.171470Z", + "shell.execute_reply": "2024-09-12T01:19:55.170810Z", + "shell.execute_reply.started": "2024-09-12T01:16:35.184582Z" + }, + "id": "ba2dcc9a-3a06-4b84-a487-59a768eed5d5", + "scrolled": true + }, + "outputs": [], + "source": [ + "param_space = {\n", + " \"k_bd\": [3, 5, 10],\n", + " \"dist_bd\": [5, 10, 15, 20],\n", + " \"k_tx\": [3, 5, 10],\n", + " \"dist_tx\": [3, 5, 10],\n", + "}\n", + "\n", + "metrics = []\n", + "\n", + "for params in itertools.product(*param_space.values()):\n", + "\n", + " config = dict(zip(param_space.keys(), params))\n", + "\n", + " # Setup directories\n", + " trial_dir = tuning_dir / '_'.join([f'{k}={v}' for k, v in config.items()])\n", + 
"\n", + " data_dir = trial_dir / 'segger_data'\n", + " data_dir.mkdir(exist_ok=True, parents=True)\n", + " config['data_dir'] = data_dir\n", + "\n", + " model_dir = trial_dir / 'models'\n", + " model_dir.mkdir(exist_ok=True, parents=True)\n", + " config['model_dir'] = model_dir\n", + "\n", + " segmentation = trainable(config)\n", + " trial = evaluate(segmentation, predict_kwargs['score_cut'])\n", + " trial = pd.concat([pd.Series(config), trial])\n", + " metrics.append(trial)\n", + "\n", + "metrics = pd.DataFrame(metrics)" + ] + }, + { + "cell_type": "markdown", + "id": "dcfa5570-ada2-4102-aae0-a3830d304c5f", + "metadata": { + "id": "dcfa5570-ada2-4102-aae0-a3830d304c5f" + }, + "source": [ + "### Interpreting Output Metrics\n", + "\n", + "The key output metrics include:\n", + "- **`frac_assigned`**: The fraction of transcripts that were successfully assigned to a cell. A higher value indicates that the model is doing a good job associating transcripts with nuclei, which is a strong indicator of successful segmentation.\n", + "- **`cell_size_assigned`**: The average size of cells that have assigned transcripts. This helps assess how well the model is predicting cell boundaries, with unusually large or small values indicating potential issues with segmentation accuracy.\n", + "- **`cell_size_cc`**: The average size of connected components that were not assigned to a cell (i.e., nucleus-less regions). 
Large values here may suggest that transcripts are being incorrectly grouped together in the absence of a nucleus, which could indicate problems with the receptive field parameters or the segmentation process.\n", + "\n", + "These metrics illuminate the effectiveness of the model by highlighting both the success in associating transcripts with cells and potential areas where the model may need further tuning.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a89aed4-c53b-460f-8a6f-f690920b6829", + "metadata": { + "execution": { + "iopub.status.busy": "2024-09-12T01:19:55.171961Z", + "iopub.status.idle": "2024-09-12T01:19:55.172161Z", + "shell.execute_reply": "2024-09-12T01:19:55.172071Z", + "shell.execute_reply.started": "2024-09-12T01:19:55.172062Z" + }, + "id": "1a89aed4-c53b-460f-8a6f-f690920b6829" + }, + "outputs": [], + "source": [ + "metrics" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 5 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } From 7e6b4d8b0f26cba4c54b6139718aaabb7a784c70 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Jan 2025 00:19:03 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 6 +- scripts/create_data_fast_sample.py | 74 +++++------- scripts/predict_model_sample.py | 6 +- scripts/sand.py | 142 +++++++++++------------ scripts/train_model_sample.py | 2 +- src/segger/data/Untitled-4.py | 34 +++--- src/segger/data/parquet/_utils.py | 1 + 
src/segger/prediction/predict_parquet.py | 6 +- 8 files changed, 121 insertions(+), 150 deletions(-) diff --git a/README.md b/README.md index 13d7f3c..4fb0e83 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,8 @@ [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/EliHei2/segger_dev/main.svg)](https://results.pre-commit.ci/latest/github/EliHei2/segger_dev/main) - - **Important note (Dec 2024)**: As segger is currently undergoing constant development we highly recommending installing segger directly via github. - **segger** is a cutting-edge tool for **cell segmentation** in **single-molecule spatial omics** datasets. By leveraging **graph neural networks (GNNs)** and heterogeneous graphs, segger offers unmatched accuracy and scalability. # How segger Works @@ -52,7 +49,7 @@ segger tackles these with a **graph-based approach**, achieving superior segment --- -## Installation +## Installation **Important note (Dec 2024)**: As segger is currently undergoing constant development we highly recommending installing segger directly via github. @@ -78,7 +75,6 @@ pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv - Afterwards choose the installation method that best suits your needs. 
- ### GitHub Installation For a straightforward local installation from GitHub, clone the repository and install the package using `pip`: diff --git a/scripts/create_data_fast_sample.py b/scripts/create_data_fast_sample.py index 54b8845..573dd1e 100644 --- a/scripts/create_data_fast_sample.py +++ b/scripts/create_data_fast_sample.py @@ -7,87 +7,77 @@ import numpy as np from segger.data.parquet._utils import get_polygons_from_xy -xenium_data_dir = Path('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/') -segger_data_dir = Path('data_tidy/pyg_datasets/bc_rep1_emb') +xenium_data_dir = Path("data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/") +segger_data_dir = Path("data_tidy/pyg_datasets/bc_rep1_emb") -scrnaseq_file = Path('/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad') -celltype_column = 'celltype_major' -gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding( - sc.read(scrnaseq_file), - celltype_column -) +scrnaseq_file = Path("/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad") +celltype_column = "celltype_major" +gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(sc.read(scrnaseq_file), celltype_column) sample = STSampleParquet( base_dir=xenium_data_dir, n_workers=4, - sample_type='xenium', - weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available + sample_type="xenium", + weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available ) -transcripts = pd.read_parquet( - xenium_data_dir / 'transcripts.parquet', - filters=[[('overlaps_nucleus', '=', 1)]] -) -boundaries = pd.read_parquet(xenium_data_dir / 'nucleus_boundaries.parquet') +transcripts = pd.read_parquet(xenium_data_dir / "transcripts.parquet", filters=[[("overlaps_nucleus", "=", 1)]]) +boundaries = pd.read_parquet(xenium_data_dir / "nucleus_boundaries.parquet") -sizes = 
transcripts.groupby('cell_id').size() -polygons = get_polygons_from_xy(boundaries, 'vertex_x', 'vertex_y', 'cell_id') +sizes = transcripts.groupby("cell_id").size() +polygons = get_polygons_from_xy(boundaries, "vertex_x", "vertex_y", "cell_id") densities = polygons[sizes.index].area / sizes bd_width = polygons.minimum_bounding_radius().median() * 2 # 1/4 median boundary diameter dist_tx = bd_width / 4 # 90th percentile density of bounding circle with radius=dist_tx -k_tx = math.ceil(np.quantile(dist_tx ** 2 * np.pi * densities, 0.9)) +k_tx = math.ceil(np.quantile(dist_tx**2 * np.pi * densities, 0.9)) print(k_tx) print(dist_tx) sample.save( - data_dir=segger_data_dir, - k_bd=3, - dist_bd=15.0, - k_tx=dist_tx, - dist_tx=k_tx, - tile_width=120, - tile_height=120, - neg_sampling_ratio=5.0, - frac=1.0, - val_prob=0.1, - test_prob=0.1, + data_dir=segger_data_dir, + k_bd=3, + dist_bd=15.0, + k_tx=dist_tx, + dist_tx=k_tx, + tile_width=120, + tile_height=120, + neg_sampling_ratio=5.0, + frac=1.0, + val_prob=0.1, + test_prob=0.1, ) -xenium_data_dir = Path('data_tidy/bc_5k') -segger_data_dir = Path('data_tidy/pyg_datasets/bc_5k_emb') - +xenium_data_dir = Path("data_tidy/bc_5k") +segger_data_dir = Path("data_tidy/pyg_datasets/bc_5k_emb") sample = STSampleParquet( base_dir=xenium_data_dir, n_workers=1, - sample_type='xenium', - weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available + sample_type="xenium", + weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available ) -transcripts = pd.read_parquet( - xenium_data_dir / 'transcripts.parquet', - filters=[[('overlaps_nucleus', '=', 1)]] -) -boundaries = pd.read_parquet(xenium_data_dir / 'nucleus_boundaries.parquet') +transcripts = pd.read_parquet(xenium_data_dir / "transcripts.parquet", filters=[[("overlaps_nucleus", "=", 1)]]) +boundaries = pd.read_parquet(xenium_data_dir / "nucleus_boundaries.parquet") -sizes = transcripts.groupby('cell_id').size() 
-polygons = get_polygons_from_xy(boundaries, 'vertex_x', 'vertex_y', 'cell_id') +sizes = transcripts.groupby("cell_id").size() +polygons = get_polygons_from_xy(boundaries, "vertex_x", "vertex_y", "cell_id") densities = polygons[sizes.index].area / sizes bd_width = polygons.minimum_bounding_radius().median() * 2 # 1/4 median boundary diameter dist_tx = bd_width / 4 # 90th percentile density of bounding circle with radius=dist_tx -k_tx = math.ceil(np.quantile(dist_tx ** 2 * np.pi * densities, 0.9)) +k_tx = math.ceil(np.quantile(dist_tx**2 * np.pi * densities, 0.9)) print(k_tx) print(dist_tx) diff --git a/scripts/predict_model_sample.py b/scripts/predict_model_sample.py index 9624abc..c1bc9c0 100644 --- a/scripts/predict_model_sample.py +++ b/scripts/predict_model_sample.py @@ -16,12 +16,11 @@ import dask.dataframe as dd - seg_tag = "bc_fast_data_emb_major" model_version = 1 -segger_data_dir = Path('data_tidy/pyg_datasets') / seg_tag -models_dir = Path("./models") / seg_tag +segger_data_dir = Path("data_tidy/pyg_datasets") / seg_tag +models_dir = Path("./models") / seg_tag benchmarks_dir = Path("/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc") transcripts_file = "data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet" # Initialize the Lightning data module @@ -58,4 +57,3 @@ gpu_ids=["0"], # client=client ) - diff --git a/scripts/sand.py b/scripts/sand.py index 447c5e6..ab480c0 100644 --- a/scripts/sand.py +++ b/scripts/sand.py @@ -32,7 +32,7 @@ def process_group(group, area_low, area_high): cell_boundary = generate_boundary(seg_cell) if isinstance(cell_boundary, MultiPolygon): # polygons = sorted(cell_boundary, key=lambda p: p.area, reverse=True) - # Extract the largest polygon (by area) from MultiPolygon + # Extract the largest polygon (by area) from MultiPolygon # cell_boundary = polygons[0] # cell_boundary = unary_union(cell_boundary) # polygons = sorted(cell_boundary, key=lambda p: p.area, reverse=True) @@ 
-72,7 +72,6 @@ def process_group(group, area_low, area_high): }, "seg_mask_value": uint_cell_id, } - def get_coordinates(boundary): @@ -86,8 +85,8 @@ def get_coordinates(boundary): elif isinstance(boundary, Polygon): # Return coordinates from a single Polygon return list(boundary.exterior.coords) - - + + def get_flatten_version(polygon_vertices, max_value=21): """ Flattens and standardizes the shape of polygon vertices to a fixed length. @@ -108,8 +107,11 @@ def get_flatten_version(polygon_vertices, max_value=21): flattened.append(vertices + [(0.0, 0.0)] * (max_value - len(vertices))) return np.array(flattened, dtype=np.float32) + from segger.validation.xenium_explorer import * from segger.prediction.boundary import * + + def seg2explorer( seg_df: pd.DataFrame, source_path: str, @@ -155,18 +157,13 @@ def seg2explorer( cell_id = [res["uint_cell_id"] for res in results] cell_summary = [res["cell_summary"] for res in results] # print('********************************1') - polygon_num_vertices = [ - len(res["cell_boundary"].exterior.coords) for res in results - ] + polygon_num_vertices = [len(res["cell_boundary"].exterior.coords) for res in results] print(polygon_num_vertices) # polygon_vertices = np.array( # [list(res["cell_boundary"].exterior.coords) for res in results], # dtype=object # ) - polygon_vertices = np.array( - [list(res["cell_boundary"].exterior.coords) for res in results], - dtype=object - ) + polygon_vertices = np.array([list(res["cell_boundary"].exterior.coords) for res in results], dtype=object) # polygon_vertices = get_flatten_version( # [get_coordinates(res["cell_boundary"]) for res in results], # max_value=21, @@ -182,7 +179,7 @@ def seg2explorer( "polygon_vertices": np.array(polygon_vertices).astype(np.float32), "seg_mask_value": np.array(seg_mask_value, dtype=np.int32), } - print(len(cells['cell_id'])) + print(len(cells["cell_id"])) # Save cells data existing_store = zarr.open(source_path / "cells.zarr.zip", mode="r") new_store = 
zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") @@ -194,7 +191,7 @@ def seg2explorer( new_store.attrs["number_cells"] = len(cells["cell_id"]) print(new_store) new_store.store.close() - + print(cells["polygon_vertices"]) # # Save analysis data if analysis_df is None: @@ -242,8 +239,6 @@ def seg2explorer( ) - - def seg2explorer( seg_df: pd.DataFrame, source_path: str, @@ -274,6 +269,7 @@ def seg2explorer( """ import zarr import json + source_path = Path(source_path) storage = Path(output_dir) cell_id2old_id = {} @@ -297,9 +293,9 @@ def seg2explorer( cell_convex_hull = get_boundary(seg_cell) # print(cell_convex_hull) if isinstance(cell_convex_hull, MultiPolygon): - # polygons = sorted(cell_boundary, key=lambda p: p.area, reverse=True) + # polygons = sorted(cell_boundary, key=lambda p: p.area, reverse=True) # Extract the largest polygon (by area) from MultiPolygon - # cell_boundary = polygons[0] + # cell_boundary = polygons[0] # cell_convex_hull = unary_union(cell_convex_hull) # print('****************1') # polygons = sorted(cell_convex_hull.geoms, key=lambda p: p.area, reverse=True) @@ -404,100 +400,96 @@ def seg2explorer( cells_name=cells_filename, analysis_name=analysis_filename, ) - - -ddf = dd.read_parquet('/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc/parquet_train_big_0.5_False_3_10_5_3_20241030/segger_transcripts.parquet').compute() +ddf = dd.read_parquet( + "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc/parquet_train_big_0.5_False_3_10_5_3_20241030/segger_transcripts.parquet" +).compute() ddf = ddf.dropna() -ddf = ddf[ddf.segger_cell_id != 'None'] -ddf = ddf.sort_values('segger_cell_id') +ddf = ddf[ddf.segger_cell_id != "None"] +ddf = ddf.sort_values("segger_cell_id") df = ddf.iloc[:10000, :] - -df_path = Path('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep1_v9_segger.csv.gz') +df_path = Path("data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep1_v9_segger.csv.gz") df_v9 = 
dd.read_csv(df_path) -df_main = dd.read_parquet('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/transcripts.parquet') +df_main = dd.read_parquet("data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/transcripts.parquet") -ddf = df_v9.merge(df_main, on='transcript_id') +ddf = df_v9.merge(df_main, on="transcript_id") ddf = ddf.compute() -ddf = ddf[ddf.segger_cell_id != 'None'] -ddf = ddf.sort_values('segger_cell_id') -df = ddf.loc[(ddf.x_location > 250) & (ddf.x_location < 1500) & (ddf.y_location > 500) & (ddf.y_location < 1500),:] +ddf = ddf[ddf.segger_cell_id != "None"] +ddf = ddf.sort_values("segger_cell_id") +df = ddf.loc[(ddf.x_location > 250) & (ddf.x_location < 1500) & (ddf.y_location > 500) & (ddf.y_location < 1500), :] # tx_df = dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') # ddf = tx_df.merge(df_main, on='transcript_id') seg2explorer( - seg_df = df, - source_path = 'data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs', - output_dir = 'data_tidy/explorer/rep1sis', - cells_filename = "segger_cells_seg_roi1", - analysis_filename = "segger_analysis_seg_roi1", - xenium_filename = "segger_experiment_seg_roi1.xenium", - analysis_df = None, - cell_id_columns = "segger_cell_id", - area_low = 10, - area_high= 1000, + seg_df=df, + source_path="data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs", + output_dir="data_tidy/explorer/rep1sis", + cells_filename="segger_cells_seg_roi1", + analysis_filename="segger_analysis_seg_roi1", + xenium_filename="segger_experiment_seg_roi1.xenium", + analysis_df=None, + cell_id_columns="segger_cell_id", + area_low=10, + area_high=1000, ) - -df = ddf.loc[(ddf.x_location > 1550) & (ddf.x_location < 3250) & (ddf.y_location > 2250) & (ddf.y_location < 3550),:] +df = ddf.loc[(ddf.x_location > 1550) & (ddf.x_location < 3250) & (ddf.y_location > 2250) & (ddf.y_location < 3550), :] # tx_df = 
dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') # ddf = tx_df.merge(df_main, on='transcript_id') seg2explorer( - seg_df = df, - source_path = 'data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs', - output_dir = 'data_tidy/explorer/rep1sis', - cells_filename = "segger_cells_seg_roi2", - analysis_filename = "segger_analysis_seg_roi2", - xenium_filename = "segger_experiment_seg_roi2.xenium", - analysis_df = None, - cell_id_columns = "segger_cell_id", - area_low = 10, - area_high= 1000, + seg_df=df, + source_path="data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs", + output_dir="data_tidy/explorer/rep1sis", + cells_filename="segger_cells_seg_roi2", + analysis_filename="segger_analysis_seg_roi2", + xenium_filename="segger_experiment_seg_roi2.xenium", + analysis_df=None, + cell_id_columns="segger_cell_id", + area_low=10, + area_high=1000, ) - -df = ddf.loc[(ddf.x_location > 4000) & (ddf.x_location < 4500) & (ddf.y_location > 1000) & (ddf.y_location < 1500),:] +df = ddf.loc[(ddf.x_location > 4000) & (ddf.x_location < 4500) & (ddf.y_location > 1000) & (ddf.y_location < 1500), :] # tx_df = dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') # ddf = tx_df.merge(df_main, on='transcript_id') seg2explorer( - seg_df = df, - source_path = 'data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs', - output_dir = 'data_tidy/explorer/rep1sis', - cells_filename = "segger_cells_seg_roi3", - analysis_filename = "segger_analysis_seg_roi3", - xenium_filename = "segger_experiment_seg_roi3.xenium", - analysis_df = None, - cell_id_columns = "segger_cell_id", - area_low = 10, - area_high= 1000, + seg_df=df, + source_path="data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs", + output_dir="data_tidy/explorer/rep1sis", + cells_filename="segger_cells_seg_roi3", + analysis_filename="segger_analysis_seg_roi3", + xenium_filename="segger_experiment_seg_roi3.xenium", + analysis_df=None, + 
cell_id_columns="segger_cell_id", + area_low=10, + area_high=1000, ) -df = ddf.loc[(ddf.x_location > 1550) & (ddf.x_location < 3250) & (ddf.y_location > 2250) & (ddf.y_location < 3550),:] +df = ddf.loc[(ddf.x_location > 1550) & (ddf.x_location < 3250) & (ddf.y_location > 2250) & (ddf.y_location < 3550), :] # tx_df = dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') # ddf = tx_df.merge(df_main, on='transcript_id') seg2explorer( - seg_df = df, - source_path = 'data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs', - output_dir = 'data_tidy/explorer/rep1sis', - cells_filename = "segger_cells_seg_roi2", - analysis_filename = "segger_analysis_seg_roi2", - xenium_filename = "segger_experiment_seg_roi2.xenium", - analysis_df = None, - cell_id_columns = "segger_cell_id", - area_low = 10, - area_high= 1000, + seg_df=df, + source_path="data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs", + output_dir="data_tidy/explorer/rep1sis", + cells_filename="segger_cells_seg_roi2", + analysis_filename="segger_analysis_seg_roi2", + xenium_filename="segger_experiment_seg_roi2.xenium", + analysis_df=None, + cell_id_columns="segger_cell_id", + area_low=10, + area_high=1000, ) - diff --git a/scripts/train_model_sample.py b/scripts/train_model_sample.py index 411cee8..826989f 100644 --- a/scripts/train_model_sample.py +++ b/scripts/train_model_sample.py @@ -15,7 +15,7 @@ import os -segger_data_dir = segger_data_dir = Path('data_tidy/pyg_datasets/bc_fast_data_emb_minor') +segger_data_dir = segger_data_dir = Path("data_tidy/pyg_datasets/bc_fast_data_emb_minor") models_dir = Path("./models/bc_fast_data_emb_minor") dm = SeggerDataModule( diff --git a/src/segger/data/Untitled-4.py b/src/segger/data/Untitled-4.py index 51b3034..dea3858 100644 --- a/src/segger/data/Untitled-4.py +++ b/src/segger/data/Untitled-4.py @@ -1,38 +1,34 @@ -df_path = Path('/omics/groups/OE0606/internal/gleb/xenium/elyas/rep1/baysor_rep1.csv') +df_path = 
Path("/omics/groups/OE0606/internal/gleb/xenium/elyas/rep1/baysor_rep1.csv") df = dd.read_csv(df_path) -res = compute_cells_per_nucleus(df, new_cell_col='baysor_cell_id') +res = compute_cells_per_nucleus(df, new_cell_col="baysor_cell_id") plot_gradient_nuclei_histogram(res, figures_path) - - - -df_path = Path('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep1_v9_segger.csv.gz') +df_path = Path("data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep1_v9_segger.csv.gz") df_v9 = dd.read_csv(df_path) -df_main = dd.read_parquet('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/transcripts.parquet') +df_main = dd.read_parquet("data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/transcripts.parquet") -ddf = df_v9.merge(df_main, on='transcript_id') +ddf = df_v9.merge(df_main, on="transcript_id") # df = dd.read_parquet(segger_emb_path/'segger_transcripts.parquet') ddf1 = ddf.compute() -res = compute_cells_per_nucleus(df.compute(), new_cell_col='segger_cell_id') -plot_gradient_nuclei_histogram(res, figures_path, 'segger') +res = compute_cells_per_nucleus(df.compute(), new_cell_col="segger_cell_id") +plot_gradient_nuclei_histogram(res, figures_path, "segger") dfff = df.compute() dfff = dfff.dropna() gene_dropout = calculate_gene_dropout_rate(dfff) -plot_z_distance_distribution(dfff, figures_path, title='Z-Distance Distribution') +plot_z_distance_distribution(dfff, figures_path, title="Z-Distance Distribution") -plot_z_distance_boxplot(dfff, figures_path, title='Z-Distance Distribution') +plot_z_distance_boxplot(dfff, figures_path, title="Z-Distance Distribution") -tx_df = dd.read_csv('data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz') -ddf = tx_df.merge(df_main, on='transcript_id') - +tx_df = dd.read_csv("data_tidy/Xenium_FFPE_Human_Breast_Cancer_Rep2_v9_segger.csv.gz") +ddf = tx_df.merge(df_main, on="transcript_id") from segger.prediction.boundary import generate_boundary @@ -44,14 +40,12 @@ from pqdm.processes import pqdm # or use 
pqdm.threads for threading-based parallelism + # Modify the function to work with a single group to use with pqdm def process_group(group): cell_id, t = group - return { - "cell_id": cell_id, - "length": len(t), - "geom": generate_boundary(t, x="x_location", y="y_location") - } + return {"cell_id": cell_id, "length": len(t), "geom": generate_boundary(t, x="x_location", y="y_location")} + def generate_boundaries(df, x="x_location", y="y_location", cell_id="segger_cell_id", n_jobs=10): # Group by cell_id diff --git a/src/segger/data/parquet/_utils.py b/src/segger/data/parquet/_utils.py index 9010443..92791a8 100644 --- a/src/segger/data/parquet/_utils.py +++ b/src/segger/data/parquet/_utils.py @@ -55,6 +55,7 @@ def get_xy_extents( # If statistics are not available, compute them manually from the data except: import gc + print("metadata lacks the statistics of the tile's bounding box, computing might take longer!") parquet_file = pd.read_parquet(filepath) x_col = parquet_file.loc[:, x] diff --git a/src/segger/prediction/predict_parquet.py b/src/segger/prediction/predict_parquet.py index e4d499b..43dacf2 100644 --- a/src/segger/prediction/predict_parquet.py +++ b/src/segger/prediction/predict_parquet.py @@ -221,7 +221,7 @@ def get_similarity_scores( coords_1 = coords_2 = batch[to_type].pos else: coords_1 = batch[to_type].pos[:, :2] # 'tx' positions - coords_2 = batch[from_type].pos[:, :2] + coords_2 = batch[from_type].pos[:, :2] if knn_method == "kd_tree": # Compute edge indices using knn method (still on GPU) edge_index = get_edge_index( @@ -244,9 +244,9 @@ def get_similarity_scores( edge_index = coo_to_dense_adj(edge_index.T, num_nodes=shape[0], num_nbrs=receptive_field[f"k_{to_type}"]) with torch.no_grad(): - if from_type != to_type: + if from_type != to_type: embeddings = model(batch.x_dict, batch.edge_index_dict) - else: # to go with the inital embeddings for tx-tx + else: # to go with the inital embeddings for tx-tx embeddings = {key: model.node_init[key](x) for 
key, x in batch.x_dict.items()} norms = embeddings[to_type].norm(dim=1, keepdim=True) # Avoid division by zero in case there are zero vectors