diff --git a/scripts/config.yaml b/scripts/config.yaml
index a2121e3..38b74e0 100644
--- a/scripts/config.yaml
+++ b/scripts/config.yaml
@@ -62,7 +62,7 @@ prediction:
   min_transcripts: 5
   cell_id_col: segger_cell_id
   use_cc: false
-  knn_method: "cuda"
+  knn_method: "kd_tree"
   file_format: "anndata"
   k_bd: 4
   dist_bd: 15.0
diff --git a/src/segger/data/utils.py b/src/segger/data/utils.py
index 7c2d83e..201eff3 100644
--- a/src/segger/data/utils.py
+++ b/src/segger/data/utils.py
@@ -314,6 +314,10 @@ def get_edge_index_kdtree(
     Returns:
         torch.Tensor: Edge indices.
     """
+    if isinstance(coords_1, torch.Tensor):
+        coords_1 = coords_1.cpu().numpy()
+    if isinstance(coords_2, torch.Tensor):
+        coords_2 = coords_2.cpu().numpy()
     tree = cKDTree(coords_1)
     d_kdtree, idx_out = tree.query(coords_2, k=k, distance_upper_bound=dist, workers=workers)
     valid_mask = d_kdtree < dist
diff --git a/src/segger/models/segger_model.py b/src/segger/models/segger_model.py
index ea3f6c6..f470574 100644
--- a/src/segger/models/segger_model.py
+++ b/src/segger/models/segger_model.py
@@ -70,19 +70,19 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
         x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
         # First layer
         x = x.relu()
-        x = self.conv_first(x, edge_index) + self.lin_first(x)
+        x = self.conv_first(x, edge_index) + self.lin_first(x)
         x = x.relu()
 
         # Middle layers
         if self.num_mid_layers > 0:
             for i in range(self.num_mid_layers):
                 conv_mid = self.conv_mid_layers[i]
-                lin_mid = self.lin_mid_layers[i]
-                x = conv_mid(x, edge_index) + lin_mid(x)
+                lin_mid = self.lin_mid_layers[i]
+                x = conv_mid(x, edge_index) + lin_mid(x)
                 x = x.relu()
 
         # Last layer
-        x = self.conv_last(x, edge_index) + self.lin_last(x)
+        x = self.conv_last(x, edge_index) + self.lin_last(x)
 
         return x
diff --git a/src/segger/prediction/predict.py b/src/segger/prediction/predict.py
index 31eac9d..da8f7cd 100644
--- a/src/segger/prediction/predict.py
+++ b/src/segger/prediction/predict.py
@@ -271,7 +271,9 @@ def _get_id():
         row_cpu = scores_tx.row.get()  # Transfer row indices to CPU (NumPy)
         col_cpu = scores_tx.col.get()  # Transfer column indices to CPU (NumPy)
         # Remove from memory
-        scores_tx = get_similarity_scores(lit_segger.model, batch, "tx", "tx", receptive_field)
+        scores_tx = get_similarity_scores(
+            lit_segger.model, batch, "tx", "tx", receptive_field, knn_method=knn_method
+        )
         # Convert to dense NumPy array
         data_cpu = scores_tx.data.get()  # Transfer data to CPU (NumPy)
         row_cpu = scores_tx.row.get()  # Transfer row indices to CPU (NumPy)
diff --git a/src/segger/prediction/predict_parquet.py b/src/segger/prediction/predict_parquet.py
index 58a76dd..1d8824b 100644
--- a/src/segger/prediction/predict_parquet.py
+++ b/src/segger/prediction/predict_parquet.py
@@ -344,7 +344,14 @@ def _get_id():
     # Step 3: Handle unassigned transcripts with connected components (if use_cc=True)
     if use_cc:
         scores_tx = get_similarity_scores(
-            lit_segger.model, batch, "tx", "tx", receptive_field, compute_sigmoid = False, knn_method=knn_method, gpu_id=gpu_id
+            lit_segger.model,
+            batch,
+            "tx",
+            "tx",
+            receptive_field,
+            compute_sigmoid=False,
+            knn_method=knn_method,
+            gpu_id=gpu_id,
         )
 
         # Stay on GPU and use CuPy sparse matrices
@@ -399,7 +406,7 @@ def _get_id():
     # Step 4: Convert assignments to Dask-CuDF DataFrame for this batch
     # batch_ddf = dask_cudf.from_cudf(cudf.DataFrame(assignments), npartitions=1)
     assignments = pd.DataFrame(assignments)
-    assignments = assignments[assignments['bound'] == 1]
+    assignments = assignments[assignments["bound"] == 1]
     batch_ddf = delayed(dd.from_pandas)(assignments, npartitions=1)
 
     # Save the updated `output_ddf` asynchronously using Dask delayed
@@ -518,57 +525,57 @@ def segment(
     seg_final_dd = pd.read_parquet(output_ddf_save_path)
     seg_final_dd = seg_final_dd.set_index("transcript_id")
-    
+
     step_start_time = time()
     if verbose:
         print(f"Applying max score selection logic...")
-    
+
     # Step 1: Find max bound indices (bound == 1) and max unbound indices (bound == 0)
     max_bound_idx = seg_final_dd[seg_final_dd["bound"] == 1].groupby("transcript_id")["score"].idxmax()
     max_unbound_idx = seg_final_dd[seg_final_dd["bound"] == 0].groupby("transcript_id")["score"].idxmax()
-    
+
     # Step 2: Combine indices, prioritizing bound=1 scores
     final_idx = max_bound_idx.combine_first(max_unbound_idx)
-    
+
     # Step 3: Use the computed final_idx to select the best assignments
     # Make sure you are using the divisions and set the index correctly before loc
     seg_final_filtered = seg_final_dd.loc[final_idx]
-    
+
     if verbose:
         elapsed_time = time() - step_start_time
         print(f"Max score selection completed in {elapsed_time:.2f} seconds.")
-    
+
     # Step 3: Load the transcripts DataFrame and merge results
-    
+
     if verbose:
         print(f"Loading transcripts from {transcript_file}...")
-    
+
     transcripts_df = pd.read_parquet(transcript_file)
     transcripts_df["transcript_id"] = transcripts_df["transcript_id"].astype(str)
-    
+
     step_start_time = time()
     if verbose:
         print(f"Merging segmentation results with transcripts...")
-    
+
     # Outer merge to include all transcripts, even those without assigned cell ids
     transcripts_df_filtered = transcripts_df.merge(seg_final_filtered, on="transcript_id", how="outer")
-    
+
     if verbose:
         elapsed_time = time() - step_start_time
         print(f"Merged segmentation results with transcripts in {elapsed_time:.2f} seconds.")
-    
+
     step_start_time = time()
     if verbose:
         print(f"Computing connected components for unassigned transcripts...")
 
     # Load edge indices from saved Parquet
     edge_index_dd = pd.read_parquet(edge_index_save_path)
-    
+
     # Step 2: Get unique transcript_ids from edge_index_dd and their positional indices
     transcript_ids_in_edges = pd.concat([edge_index_dd["source"], edge_index_dd["target"]]).unique()
-    
+
     # Create a lookup table with unique indices
     lookup_table = pd.Series(data=range(len(transcript_ids_in_edges)), index=transcript_ids_in_edges).to_dict()
-    
+
     # Map source and target to positional indices
     edge_index_dd["index_source"] = edge_index_dd["source"].map(lookup_table)
     edge_index_dd["index_target"] = edge_index_dd["target"].map(lookup_table)
@@ -576,44 +583,44 @@ def segment(
     source_indices = np.asarray(edge_index_dd["index_source"])
     target_indices = np.asarray(edge_index_dd["index_target"])
     data_cp = np.ones(len(source_indices), dtype=np.float32)
-    
+
     # Create the sparse COO matrix
     coo_cp_matrix = scipy_coo_matrix(
         (data_cp, (source_indices, target_indices)),
         shape=(len(transcript_ids_in_edges), len(transcript_ids_in_edges)),
     )
-    
+
     # Use CuPy's connected components algorithm to compute components
     n, comps = cc(coo_cp_matrix, directed=True, connection="strong")
     if verbose:
         elapsed_time = time() - step_start_time
         print(f"Computed connected components for unassigned transcripts in {elapsed_time:.2f} seconds.")
-    
+
     step_start_time = time()
     if verbose:
         print(f"The rest...")
     # # Step 4: Map back the component labels to the original transcript_ids
-    
+
     def _get_id():
         """Generate a random Xenium-style ID."""
         return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx"
-    
+
     new_ids = np.array([_get_id() for _ in range(n)])
     comp_labels = new_ids[comps]
     comp_labels = pd.Series(comp_labels, index=transcript_ids_in_edges)
     # Step 5: Handle only unassigned transcripts in transcripts_df_filtered
     unassigned_mask = transcripts_df_filtered["segger_cell_id"].isna()
-    
+
     unassigned_transcripts_df = transcripts_df_filtered.loc[unassigned_mask, ["transcript_id"]]
-    
+
     # Step 6: Map component labels only to unassigned transcript_ids
     new_segger_cell_ids = unassigned_transcripts_df["transcript_id"].map(comp_labels)
-    
+
     # Step 7: Create a DataFrame with updated 'segger_cell_id' for unassigned transcripts
     unassigned_transcripts_df = unassigned_transcripts_df.assign(segger_cell_id=new_segger_cell_ids)
-    
+
     # Step 8: Merge this DataFrame back into the original to update only the unassigned segger_cell_id
-    
+
     # Merging the updates back to the original DataFrame
     transcripts_df_filtered = transcripts_df_filtered.merge(
         unassigned_transcripts_df[["transcript_id", "segger_cell_id"]],
@@ -621,14 +628,14 @@ def _get_id():
         how="left",  # Perform a left join to only update the unassigned rows
         suffixes=("", "_new"),  # Suffix for new column to avoid overwriting
     )
-    
+
     # Step 9: Fill missing segger_cell_id values with the updated values from the merge
     transcripts_df_filtered["segger_cell_id"] = transcripts_df_filtered["segger_cell_id"].fillna(
         transcripts_df_filtered["segger_cell_id_new"]
     )
     transcripts_df_filtered = transcripts_df_filtered.drop(columns=["segger_cell_id_new"])
-    
+
     if verbose:
         elapsed_time = time() - step_start_time
         print(f"The rest computed in {elapsed_time:.2f} seconds.")
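
Notes on the patch follow; the sketches below are illustrative and not part of the diff.

With knn_method switched from "cuda" to "kd_tree" in config.yaml, coordinates that live on the GPU can now reach the CPU kd-tree path, which is what the new guard in get_edge_index_kdtree handles. A minimal sketch, assuming SciPy's cKDTree semantics (it only accepts host-memory arrays) and made-up coordinates; k and the distance bound echo k_bd: 4 and dist_bd: 15.0 from the config:

    import numpy as np
    import torch
    from scipy.spatial import cKDTree

    # Hypothetical input standing in for coords_1/coords_2 in the patch.
    coords = torch.rand(100, 2, device="cuda" if torch.cuda.is_available() else "cpu")

    # Without the guard, cKDTree(coords) raises on a CUDA tensor because
    # NumPy cannot view GPU memory; .cpu().numpy() copies it to the host.
    if isinstance(coords, torch.Tensor):
        coords = coords.cpu().numpy()

    tree = cKDTree(coords)
    d, idx = tree.query(coords, k=4, distance_upper_bound=15.0, workers=-1)
    print(idx.shape)  # (100, 4)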
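The max score selection in segment() picks, per transcript, the highest-scoring bound == 1 assignment and falls back to the highest-scoring bound == 0 assignment only when no bound assignment exists. A toy illustration with invented rows (column names taken from the patch):

    import pandas as pd

    seg = pd.DataFrame(
        {
            "transcript_id": ["t1", "t1", "t2"],
            "bound": [1, 0, 0],
            "score": [0.4, 0.9, 0.7],
        }
    )

    max_bound_idx = seg[seg["bound"] == 1].groupby("transcript_id")["score"].idxmax()
    max_unbound_idx = seg[seg["bound"] == 0].groupby("transcript_id")["score"].idxmax()

    # combine_first keeps the bound==1 winner where one exists ("t1", even
    # though its score of 0.4 is lower than the unbound 0.9) and falls back
    # to the unbound winner elsewhere ("t2"). The alignment upcasts the row
    # labels to float, so cast back before using them with .loc.
    final_idx = max_bound_idx.combine_first(max_unbound_idx).astype(int)
    print(seg.loc[final_idx])  # t1 -> bound row (0.4), t2 -> unbound row (0.7)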
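For transcripts still unassigned after the merge, segment() builds a sparse graph over transcript ids from the saved edge index, labels its strongly connected components, and assigns each component a fresh Xenium-style id. A self-contained sketch with made-up edges, assuming that scipy_coo_matrix and cc in the patch alias the SciPy imports used here:

    import numpy as np
    import pandas as pd
    from scipy.sparse import coo_matrix
    from scipy.sparse.csgraph import connected_components

    edges = pd.DataFrame({"source": ["a", "b", "c"], "target": ["b", "a", "c"]})

    # Positional index for every transcript id that appears in an edge.
    ids = pd.concat([edges["source"], edges["target"]]).unique()
    lookup = pd.Series(range(len(ids)), index=ids).to_dict()
    src = edges["source"].map(lookup).to_numpy()
    tgt = edges["target"].map(lookup).to_numpy()

    # Sparse adjacency matrix over the ids.
    adj = coo_matrix(
        (np.ones(len(src), dtype=np.float32), (src, tgt)),
        shape=(len(ids), len(ids)),
    )
    n, comps = connected_components(adj, directed=True, connection="strong")

    def _get_id():
        """Generate a random Xenium-style ID (8 letters plus '-nx')."""
        return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx"

    # One fresh id per component: 'a' and 'b' share one, 'c' gets its own.
    comp_labels = pd.Series(np.array([_get_id() for _ in range(n)])[comps], index=ids)
    print(comp_labels)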