diff --git a/scripts/train_model.py b/scripts/train_model.py index 4e1e361..7c25bc3 100644 --- a/scripts/train_model.py +++ b/scripts/train_model.py @@ -3,7 +3,7 @@ import argparse from pathlib import Path import torch -import lightning as L +from pytorch_lightning import Trainer from torch_geometric.loader import DataLoader from segger.data.utils import SpatialTranscriptomicsDataset # Updated dataset class from segger.models.segger_model import Segger @@ -57,7 +57,7 @@ def main(args): litsegger = LitSegger(model=model) # Initialize the PyTorch Lightning trainer - trainer = L.Trainer( + trainer = Trainer( accelerator=args.accelerator, strategy=args.strategy, precision=args.precision, diff --git a/src/segger/data/io.py b/src/segger/data/io.py index 3d45afd..1f00431 100644 --- a/src/segger/data/io.py +++ b/src/segger/data/io.py @@ -815,7 +815,7 @@ def _process_tile(self, tile_params: Tuple) -> None: # Save the tile data to the appropriate directory based on split if self.verbose: print(f"Saving data for tile at (x_min: {x_loc}, y_min: {y_loc})...") - filename = f"tiles_x{x_loc}_y{y_loc}_{x_size}_{y_size}.pt" + filename = f"tiles_x={x_loc}_y={y_loc}_w={x_size}_h={y_size}.pt" if prob > val_prob + test_prob: torch.save(data, processed_dir / 'train_tiles' / 'processed' / filename) elif prob > test_prob: diff --git a/src/segger/data/parquet/pyg_dataset.py b/src/segger/data/parquet/pyg_dataset.py index a464e72..3d3e117 100644 --- a/src/segger/data/parquet/pyg_dataset.py +++ b/src/segger/data/parquet/pyg_dataset.py @@ -37,7 +37,7 @@ def processed_file_names(self) -> List[str]: Returns: List[str]: List of processed file names. """ - paths = glob.glob(f'{self.processed_dir}/x=*_y=*_w=*_h=*.pt') + paths = glob.glob(f'{self.processed_dir}/tiles_x=*_y=*_w=*_h=*.pt') file_names = list(map(os.path.basename, paths)) return file_names diff --git a/src/segger/data/parquet/sample.py b/src/segger/data/parquet/sample.py index 970f17f..9407ced 100644 --- a/src/segger/data/parquet/sample.py +++ b/src/segger/data/parquet/sample.py @@ -785,7 +785,7 @@ def uid(self) -> str: 'x=100_y=200_w=50_h=50' """ x_min, y_min, x_max, y_max = map(int, self.extents.bounds) - uid = f'x={x_min}_y={y_min}_w={x_max-x_min}_h={y_max-y_min}' + uid = f'tiles_x={x_min}_y={y_min}_w={x_max-x_min}_h={y_max-y_min}' return uid