Merge pull request #67 from daniel-unyi-42/main
Change default knn method to kd-tree
daniel-unyi-42 authored Dec 9, 2024
2 parents fb58fc1 + de7fe72 commit 8da9bc8
Showing 5 changed files with 48 additions and 35 deletions.
2 changes: 1 addition & 1 deletion scripts/config.yaml
@@ -62,7 +62,7 @@ prediction:
  min_transcripts: 5
  cell_id_col: segger_cell_id
  use_cc: false
-  knn_method: "cuda"
+  knn_method: "kd_tree"
  file_format: "anndata"
  k_bd: 4
  dist_bd: 15.0
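The new default removes a hard GPU dependency at prediction time. A hypothetical dispatcher for this setting (segger's actual wiring may differ; query_neighbors and its parameters are illustrative only) shows the trade-off: "kd_tree" runs anywhere on CPU via SciPy, while "cuda" needs a CUDA-capable device and a GPU backend.

import numpy as np
from scipy.spatial import cKDTree

def query_neighbors(points, queries, k, dist, knn_method="kd_tree"):
    # "kd_tree": portable CPU path, mirroring the new default above
    if knn_method == "kd_tree":
        return cKDTree(points).query(queries, k=k, distance_upper_bound=dist)
    # "cuda": placeholder for a GPU backend (e.g. RAPIDS cuML); needs a CUDA device
    if knn_method == "cuda":
        raise NotImplementedError("GPU knn backend not available on this machine")
    raise ValueError(f"unknown knn_method: {knn_method}")

dists, idxs = query_neighbors(np.random.rand(100, 2), np.random.rand(5, 2), k=4, dist=0.5)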
4 changes: 4 additions & 0 deletions src/segger/data/utils.py
@@ -314,6 +314,10 @@ def get_edge_index_kdtree(
    Returns:
        torch.Tensor: Edge indices.
    """
+    if isinstance(coords_1, torch.Tensor):
+        coords_1 = coords_1.cpu().numpy()
+    if isinstance(coords_2, torch.Tensor):
+        coords_2 = coords_2.cpu().numpy()
    tree = cKDTree(coords_1)
    d_kdtree, idx_out = tree.query(coords_2, k=k, distance_upper_bound=dist, workers=workers)
    valid_mask = d_kdtree < dist
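cKDTree.query returns distance inf (and index len(coords_1)) wherever no neighbor lies within the distance bound, which is why the valid_mask filter follows the query. A self-contained sketch of the same construction; the function name, toy data, and edge orientation here are assumptions, not segger's actual API:

import numpy as np
import torch
from scipy.spatial import cKDTree

def knn_edge_index_sketch(coords_1, coords_2, k=4, dist=10.0, workers=1):
    # Accept torch tensors or numpy arrays, as the patched function now does
    if isinstance(coords_1, torch.Tensor):
        coords_1 = coords_1.cpu().numpy()
    if isinstance(coords_2, torch.Tensor):
        coords_2 = coords_2.cpu().numpy()
    tree = cKDTree(coords_1)
    # Queries with no neighbor inside `dist` come back with distance inf
    d, idx = tree.query(coords_2, k=k, distance_upper_bound=dist, workers=workers)
    valid = d < dist
    rows = np.repeat(np.arange(coords_2.shape[0]), k).reshape(-1, k)[valid]
    cols = idx[valid]
    return torch.tensor(np.stack([rows, cols]), dtype=torch.long)

edges = knn_edge_index_sketch(torch.rand(100, 2) * 50.0, torch.rand(20, 2) * 50.0)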
8 changes: 4 additions & 4 deletions src/segger/models/segger_model.py
@@ -70,19 +70,19 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
        # First layer
        x = x.relu()
-        x = self.conv_first(x, edge_index) + self.lin_first(x)
+        x = self.conv_first(x, edge_index) + self.lin_first(x)
        x = x.relu()

        # Middle layers
        if self.num_mid_layers > 0:
            for i in range(self.num_mid_layers):
                conv_mid = self.conv_mid_layers[i]
-                lin_mid = self.lin_mid_layers[i]
-                x = conv_mid(x, edge_index) + lin_mid(x)
+                lin_mid = self.lin_mid_layers[i]
+                x = conv_mid(x, edge_index) + lin_mid(x)
                x = x.relu()

        # Last layer
-        x = self.conv_last(x, edge_index) + self.lin_last(x)
+        x = self.conv_last(x, edge_index) + self.lin_last(x)

        return x

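Each layer in this forward pass sums a graph convolution with a parallel node-wise linear map before the ReLU, so every node keeps a feature path that bypasses message passing. A minimal sketch of one such block, assuming a PyG GATv2Conv; the layer type, names, and sizes are illustrative, not segger's actual configuration:

import torch
from torch import nn
from torch_geometric.nn import GATv2Conv

class ConvLinBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.conv = GATv2Conv(in_dim, out_dim)
        self.lin = nn.Linear(in_dim, out_dim)  # parallel node-wise path, as in the diff

    def forward(self, x, edge_index):
        # Sum of message passing and a per-node linear map, then ReLU
        return (self.conv(x, edge_index) + self.lin(x)).relu()

x = torch.rand(6, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
out = ConvLinBlock(16, 32)(x, edge_index)  # shape: (6, 32)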
4 changes: 3 additions & 1 deletion src/segger/prediction/predict.py
@@ -271,7 +271,9 @@ def _get_id():
        row_cpu = scores_tx.row.get()  # Transfer row indices to CPU (NumPy)
        col_cpu = scores_tx.col.get()  # Transfer column indices to CPU (NumPy)
        # Remove from memory
-        scores_tx = get_similarity_scores(lit_segger.model, batch, "tx", "tx", receptive_field)
+        scores_tx = get_similarity_scores(
+            lit_segger.model, batch, "tx", "tx", receptive_field, knn_method=knn_method
+        )
        # Convert to dense NumPy array
        data_cpu = scores_tx.data.get()  # Transfer data to CPU (NumPy)
        row_cpu = scores_tx.row.get()  # Transfer row indices to CPU (NumPy)
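Around this call, scores_tx is a GPU-resident sparse matrix whose buffers are pulled to the host with .get(). A small demonstration of that transfer pattern, assuming CuPy and a CUDA device are available; the toy matrix is made up:

import cupy as cp
from cupyx.scipy.sparse import coo_matrix

scores = coo_matrix(
    (cp.asarray([0.9, 0.4], dtype=cp.float32),
     (cp.asarray([0, 1], dtype=cp.int32), cp.asarray([1, 0], dtype=cp.int32))),
    shape=(2, 2),
)
data_cpu = scores.data.get()  # .get() copies device memory into a host NumPy array
row_cpu = scores.row.get()
col_cpu = scores.col.get()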
65 changes: 36 additions & 29 deletions src/segger/prediction/predict_parquet.py
@@ -344,7 +344,14 @@ def _get_id():
        # Step 3: Handle unassigned transcripts with connected components (if use_cc=True)
        if use_cc:
            scores_tx = get_similarity_scores(
-                lit_segger.model, batch, "tx", "tx", receptive_field, compute_sigmoid = False, knn_method=knn_method, gpu_id=gpu_id
+                lit_segger.model,
+                batch,
+                "tx",
+                "tx",
+                receptive_field,
+                compute_sigmoid=False,
+                knn_method=knn_method,
+                gpu_id=gpu_id,
            )

            # Stay on GPU and use CuPy sparse matrices
@@ -399,7 +406,7 @@ def _get_id():
        # Step 4: Convert assignments to Dask-CuDF DataFrame for this batch
        # batch_ddf = dask_cudf.from_cudf(cudf.DataFrame(assignments), npartitions=1)
        assignments = pd.DataFrame(assignments)
-        assignments = assignments[assignments['bound'] == 1]
+        assignments = assignments[assignments["bound"] == 1]
        batch_ddf = delayed(dd.from_pandas)(assignments, npartitions=1)

        # Save the updated `output_ddf` asynchronously using Dask delayed
@@ -518,7 +525,117 @@ def segment(

    seg_final_dd = pd.read_parquet(output_ddf_save_path)
    seg_final_dd = seg_final_dd.set_index("transcript_id")

    step_start_time = time()
    if verbose:
        print(f"Applying max score selection logic...")

    # Step 1: Find max bound indices (bound == 1) and max unbound indices (bound == 0)
    max_bound_idx = seg_final_dd[seg_final_dd["bound"] == 1].groupby("transcript_id")["score"].idxmax()
    max_unbound_idx = seg_final_dd[seg_final_dd["bound"] == 0].groupby("transcript_id")["score"].idxmax()

    # Step 2: Combine indices, prioritizing bound=1 scores
    final_idx = max_bound_idx.combine_first(max_unbound_idx)

    # Step 3: Use the computed final_idx to select the best assignments
    # Make sure you are using the divisions and set the index correctly before loc
    seg_final_filtered = seg_final_dd.loc[final_idx]

    if verbose:
        elapsed_time = time() - step_start_time
        print(f"Max score selection completed in {elapsed_time:.2f} seconds.")

    # Step 3: Load the transcripts DataFrame and merge results

    if verbose:
        print(f"Loading transcripts from {transcript_file}...")

    transcripts_df = pd.read_parquet(transcript_file)
    transcripts_df["transcript_id"] = transcripts_df["transcript_id"].astype(str)

    step_start_time = time()
    if verbose:
        print(f"Merging segmentation results with transcripts...")

    # Outer merge to include all transcripts, even those without assigned cell ids
    transcripts_df_filtered = transcripts_df.merge(seg_final_filtered, on="transcript_id", how="outer")

    if verbose:
        elapsed_time = time() - step_start_time
        print(f"Merged segmentation results with transcripts in {elapsed_time:.2f} seconds.")

    step_start_time = time()
    if verbose:
        print(f"Computing connected components for unassigned transcripts...")
    # Load edge indices from saved Parquet
    edge_index_dd = pd.read_parquet(edge_index_save_path)

    # Step 2: Get unique transcript_ids from edge_index_dd and their positional indices
    transcript_ids_in_edges = pd.concat([edge_index_dd["source"], edge_index_dd["target"]]).unique()

    # Create a lookup table with unique indices
    lookup_table = pd.Series(data=range(len(transcript_ids_in_edges)), index=transcript_ids_in_edges).to_dict()

    # Map source and target to positional indices
    edge_index_dd["index_source"] = edge_index_dd["source"].map(lookup_table)
    edge_index_dd["index_target"] = edge_index_dd["target"].map(lookup_table)
    # Step 3: Compute connected components for transcripts involved in edges
    source_indices = np.asarray(edge_index_dd["index_source"])
    target_indices = np.asarray(edge_index_dd["index_target"])
    data_cp = np.ones(len(source_indices), dtype=np.float32)

    # Create the sparse COO matrix
    coo_cp_matrix = scipy_coo_matrix(
        (data_cp, (source_indices, target_indices)),
        shape=(len(transcript_ids_in_edges), len(transcript_ids_in_edges)),
    )
    # Use SciPy's connected components algorithm to compute components
    n, comps = cc(coo_cp_matrix, directed=True, connection="strong")
    if verbose:
        elapsed_time = time() - step_start_time
        print(f"Computed connected components for unassigned transcripts in {elapsed_time:.2f} seconds.")

    step_start_time = time()
    if verbose:
        print(f"The rest...")
    # # Step 4: Map back the component labels to the original transcript_ids

    def _get_id():
        """Generate a random Xenium-style ID."""
        return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx"

    new_ids = np.array([_get_id() for _ in range(n)])
    comp_labels = new_ids[comps]
    comp_labels = pd.Series(comp_labels, index=transcript_ids_in_edges)
    # Step 5: Handle only unassigned transcripts in transcripts_df_filtered
    unassigned_mask = transcripts_df_filtered["segger_cell_id"].isna()

    unassigned_transcripts_df = transcripts_df_filtered.loc[unassigned_mask, ["transcript_id"]]

    # Step 6: Map component labels only to unassigned transcript_ids
    new_segger_cell_ids = unassigned_transcripts_df["transcript_id"].map(comp_labels)

    # Step 7: Create a DataFrame with updated 'segger_cell_id' for unassigned transcripts
    unassigned_transcripts_df = unassigned_transcripts_df.assign(segger_cell_id=new_segger_cell_ids)

    # Step 8: Merge this DataFrame back into the original to update only the unassigned segger_cell_id

    # Merging the updates back to the original DataFrame
    transcripts_df_filtered = transcripts_df_filtered.merge(
        unassigned_transcripts_df[["transcript_id", "segger_cell_id"]],
        on="transcript_id",
        how="left",  # Perform a left join to only update the unassigned rows
        suffixes=("", "_new"),  # Suffix for new column to avoid overwriting
    )

    # Step 9: Fill missing segger_cell_id values with the updated values from the merge
    transcripts_df_filtered["segger_cell_id"] = transcripts_df_filtered["segger_cell_id"].fillna(
        transcripts_df_filtered["segger_cell_id_new"]
    )

    transcripts_df_filtered = transcripts_df_filtered.drop(columns=["segger_cell_id_new"])

    if verbose:
        elapsed_time = time() - step_start_time
        print(f"The rest computed in {elapsed_time:.2f} seconds.")
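The max score selection above hinges on combine_first: a transcript keeps its best bound assignment whenever one exists and only falls back to its best unbound score otherwise. A toy run of Steps 1 and 2 on made-up data:

import pandas as pd

seg = pd.DataFrame(
    {
        "transcript_id": ["t1", "t1", "t2", "t2"],
        "bound": [1, 0, 0, 0],
        "score": [0.3, 0.9, 0.5, 0.7],
    }
)
max_bound_idx = seg[seg["bound"] == 1].groupby("transcript_id")["score"].idxmax()
max_unbound_idx = seg[seg["bound"] == 0].groupby("transcript_id")["score"].idxmax()
# t1 keeps its bound row (index 0) even though the unbound score 0.9 is higher;
# t2 has no bound row, so combine_first falls back to index 3 (score 0.7)
final_idx = max_bound_idx.combine_first(max_unbound_idx).astype(int)
print(seg.loc[final_idx])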

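The connected-components pass labels transcripts that were never assigned to a cell: transcript ids are mapped to matrix positions, a sparse adjacency matrix is built from the saved edges, and every strongly connected component gets a fresh Xenium-style id. A toy version of Steps 2 through 4; the edge data is made up for illustration:

import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

edges = pd.DataFrame({"source": ["a", "b", "d"], "target": ["b", "a", "e"]})
ids = pd.concat([edges["source"], edges["target"]]).unique()
lookup = pd.Series(range(len(ids)), index=ids).to_dict()

src = edges["source"].map(lookup).to_numpy()
tgt = edges["target"].map(lookup).to_numpy()
adj = coo_matrix((np.ones(len(src), dtype=np.float32), (src, tgt)), shape=(len(ids), len(ids)))

n, comps = connected_components(adj, directed=True, connection="strong")

def _get_id():
    return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx"

labels = pd.Series(np.array([_get_id() for _ in range(n)])[comps], index=ids)
print(labels)  # "a" and "b" share one id; "d" and "e" get their own (strong connectivity)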
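Steps 8 and 9 implement an "update only where missing" pattern: a left merge brings the new component labels in under a _new suffix, and fillna copies them into segger_cell_id only where it was NaN. A compact demo with made-up rows:

import pandas as pd

df = pd.DataFrame({"transcript_id": ["t1", "t2"], "segger_cell_id": ["cell-1", None]})
updates = pd.DataFrame({"transcript_id": ["t2"], "segger_cell_id": ["abcdwxyz-nx"]})

df = df.merge(updates, on="transcript_id", how="left", suffixes=("", "_new"))
df["segger_cell_id"] = df["segger_cell_id"].fillna(df["segger_cell_id_new"])
df = df.drop(columns=["segger_cell_id_new"])
print(df)  # t1 keeps cell-1; t2 picks up the new component id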