diff --git a/spikesorters/basesorter.py b/spikesorters/basesorter.py index 08c41de..388c15a 100644 --- a/spikesorters/basesorter.py +++ b/spikesorters/basesorter.py @@ -20,11 +20,11 @@ import time import copy from pathlib import Path -import os import datetime import json import traceback import shutil +import warnings from joblib import Parallel, delayed import numpy as np @@ -225,18 +225,25 @@ def _run(self, recording, output_folder): def get_result_from_folder(output_folder): raise NotImplementedError - def get_result_list(self): + def get_result_list(self, raise_error=True): sorting_list = [] for i, _ in enumerate(self.recording_list): - sorting = self.get_result_from_folder(self.output_folders[i]) - sorting_list.append(sorting) + try: + sorting = self.get_result_from_folder(self.output_folders[i]) + sorting_list.append(sorting) + except Exception as err: + if raise_error: + raise SpikeSortingError(f"Failed to load sorting output {i}") + else: + warnings.warn(f"Sorting output {i} could not be loaded") return sorting_list - def get_result(self): - sorting_list = self.get_result_list() + def get_result(self, raise_error=True): + sorting_list = self.get_result_list(raise_error=raise_error) + if len(sorting_list) == 1: sorting = sorting_list[0] - else: + elif len(sorting_list) > 1: for i, sorting in enumerate(sorting_list): property_name = self.recording_list[i].get_channel_property(self.recording_list[i].get_channel_ids()[0], self.grouping_property) @@ -248,6 +255,8 @@ def get_result(self): sorting_list = [sort for sort in sorting_list if sort is not None] multi_sorting = se.MultiSortingExtractor(sortings=sorting_list) sorting = multi_sorting + else: + raise SpikeSortingError(f"None of the sorting outputs could be loaded") if self.delete_folders: for out in self.output_folders: diff --git a/spikesorters/sorterlist.py b/spikesorters/sorterlist.py index f68e4bb..74d2e7b 100644 --- a/spikesorters/sorterlist.py +++ b/spikesorters/sorterlist.py @@ -91,7 +91,7 @@ def run_sorter(sorter_name_or_class, recording, output_folder=None, delete_outpu verbose=verbose, delete_output_folder=delete_output_folder) sorter.set_params(**params) sorter.run(raise_error=raise_error, parallel=parallel, n_jobs=n_jobs, joblib_backend=joblib_backend) - sortingextractor = sorter.get_result() + sortingextractor = sorter.get_result(raise_error=raise_error) return sortingextractor