diff --git a/tests/test_train.py b/tests/test_train.py index 4fe83969..ee437114 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -14,7 +14,7 @@ pl.seed_everything(100) -def create_data(site_id: int) -> Data: +def create_data() -> Data: num_channels = 2 num_time = 12 height = 10 @@ -50,8 +50,8 @@ def test_train(): data_path = ( ppaths.process_path / f'data_{i:06d}_2021_{i:06d}_none.pt' ) - batch_data = create_data(i) - joblib.dump(batch_data, str(data_path), compress=5) + batch_data = create_data() + batch_data.to_file(data_path) dataset = EdgeDataset( ppaths.train_path,