Skip to content

Commit

Permalink
fixed discrepancy between old and new data versions
Browse files Browse the repository at this point in the history
  • Loading branch information
EliHei2 committed Oct 8, 2024
1 parent e1a7b05 commit b2adc96
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions scripts/predict_model_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@
model,
dm,
save_dir=benchmarks_dir,
seg_tag='segger_embedding_1001_0.5',
seg_tag='segger_embedding_1001_0.5_cc',
transcript_file=transcripts_file,
file_format='anndata',
receptive_field = receptive_field,
min_transcripts=5,
# max_transcripts=1500,
cell_id_col='segger_cell_id',
use_cc=False,
use_cc=True,
knn_method='cuda'
)
3 changes: 2 additions & 1 deletion src/segger/data/parquet/pyg_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def processed_file_names(self) -> List[str]:
Returns:
List[str]: List of processed file names.
"""
paths = glob.glob(f'{self.processed_dir}/tiles_x=*_y=*_w=*_h=*.pt')
paths = glob.glob(f'{self.processed_dir}/tiles_x*_y*_*_*.pt')
# paths = paths.append(paths = glob.glob(f'{self.processed_dir}/tiles_x*_y*_*_*.pt'))
file_names = list(map(os.path.basename, paths))
return file_names

Expand Down
4 changes: 2 additions & 2 deletions src/segger/prediction/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def sparse_multiply(embeddings, edge_index, shape) -> coo_matrix:
indices[1] = edge_index[edge_index != -1]
rows = cp.fromDlpack(to_dlpack(indices[0,:].to('cuda')))
columns = cp.fromDlpack(to_dlpack(indices[1,:].to('cuda')))
print(rows)
# print(rows)
del indices
values = similarity[edge_index != -1].flatten()
sparse_result = coo_matrix((cp.fromDlpack(to_dlpack(values)), (rows, columns)), shape=shape)
Expand Down Expand Up @@ -419,7 +419,7 @@ def segment(
seg_combined = pd.concat([segmentation_train, segmentation_val, segmentation_test], ignore_index=True)

# seg_combined = segmentation_test
print(seg_combined.columns)
# print(seg_combined.columns)
# print(transcripts_df.id)
# Drop any unassigned rows
seg_final = seg_combined.dropna(subset=['segger_cell_id']).reset_index(drop=True)
Expand Down

0 comments on commit b2adc96

Please sign in to comment.