diff --git a/scripts/predict_model_sample.py b/scripts/predict_model_sample.py index be2380f..074491d 100644 --- a/scripts/predict_model_sample.py +++ b/scripts/predict_model_sample.py @@ -1,7 +1,7 @@ from segger.data.io import XeniumSample from segger.training.train import LitSegger from segger.training.segger_data_module import SeggerDataModule -from segger.prediction.predict import predict, load_model +from segger.prediction.predict_gpu import predict, load_model from lightning.pytorch.loggers import CSVLogger from pytorch_lightning import Trainer from pathlib import Path @@ -14,6 +14,7 @@ import os import dask.dataframe as dd import pandas as pd +from pathlib import Path segger_data_dir = Path('./data_tidy/pyg_datasets/bc_embedding_0919')