Skip to content

Commit

Permalink
Merge pull request #21 from daniel-unyi-42/main
Browse files Browse the repository at this point in the history
Synchronize tile naming conventions
  • Loading branch information
EliHei2 authored Oct 7, 2024
2 parents 525ad3c + 228258b commit f9dff63
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions scripts/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
from pathlib import Path
import torch
import lightning as L
from pytorch_lightning import Trainer
from torch_geometric.loader import DataLoader
from segger.data.utils import SpatialTranscriptomicsDataset # Updated dataset class
from segger.models.segger_model import Segger
Expand Down Expand Up @@ -57,7 +57,7 @@ def main(args):
litsegger = LitSegger(model=model)

# Initialize the PyTorch Lightning trainer
trainer = L.Trainer(
trainer = Trainer(
accelerator=args.accelerator,
strategy=args.strategy,
precision=args.precision,
Expand Down
2 changes: 1 addition & 1 deletion src/segger/data/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def _process_tile(self, tile_params: Tuple) -> None:

# Save the tile data to the appropriate directory based on split
if self.verbose: print(f"Saving data for tile at (x_min: {x_loc}, y_min: {y_loc})...")
filename = f"tiles_x{x_loc}_y{y_loc}_{x_size}_{y_size}.pt"
filename = f"tiles_x={x_loc}_y={y_loc}_w={x_size}_h={y_size}.pt"
if prob > val_prob + test_prob:
torch.save(data, processed_dir / 'train_tiles' / 'processed' / filename)
elif prob > test_prob:
Expand Down
2 changes: 1 addition & 1 deletion src/segger/data/parquet/pyg_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def processed_file_names(self) -> List[str]:
Returns:
List[str]: List of processed file names.
"""
paths = glob.glob(f'{self.processed_dir}/x=*_y=*_w=*_h=*.pt')
paths = glob.glob(f'{self.processed_dir}/tiles_x=*_y=*_w=*_h=*.pt')
file_names = list(map(os.path.basename, paths))
return file_names

Expand Down
2 changes: 1 addition & 1 deletion src/segger/data/parquet/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ def uid(self) -> str:
'x=100_y=200_w=50_h=50'
"""
x_min, y_min, x_max, y_max = map(int, self.extents.bounds)
uid = f'x={x_min}_y={y_min}_w={x_max-x_min}_h={y_max-y_min}'
uid = f'tiles_x={x_min}_y={y_min}_w={x_max-x_min}_h={y_max-y_min}'
return uid


Expand Down

0 comments on commit f9dff63

Please sign in to comment.