Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Commit

Permalink
Merge pull request #207 from SpikeInterface/fix_raise_error_when_grou…
Browse files Browse the repository at this point in the history
…ping

Propagate raise_error to get_results
  • Loading branch information
alejoe91 authored Mar 11, 2021
2 parents d893112 + b73341a commit 12e3571
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
23 changes: 16 additions & 7 deletions spikesorters/basesorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import time
import copy
from pathlib import Path
import os
import datetime
import json
import traceback
import shutil
import warnings
from joblib import Parallel, delayed

import numpy as np
Expand Down Expand Up @@ -225,18 +225,25 @@ def _run(self, recording, output_folder):
def get_result_from_folder(output_folder):
raise NotImplementedError

def get_result_list(self):
def get_result_list(self, raise_error=True):
sorting_list = []
for i, _ in enumerate(self.recording_list):
sorting = self.get_result_from_folder(self.output_folders[i])
sorting_list.append(sorting)
try:
sorting = self.get_result_from_folder(self.output_folders[i])
sorting_list.append(sorting)
except Exception as err:
if raise_error:
raise SpikeSortingError(f"Failed to load sorting output {i}")
else:
warnings.warn(f"Sorting output {i} could not be loaded")
return sorting_list

def get_result(self):
sorting_list = self.get_result_list()
def get_result(self, raise_error=True):
sorting_list = self.get_result_list(raise_error=raise_error)

if len(sorting_list) == 1:
sorting = sorting_list[0]
else:
elif len(sorting_list) > 1:
for i, sorting in enumerate(sorting_list):
property_name = self.recording_list[i].get_channel_property(self.recording_list[i].get_channel_ids()[0],
self.grouping_property)
Expand All @@ -248,6 +255,8 @@ def get_result(self):
sorting_list = [sort for sort in sorting_list if sort is not None]
multi_sorting = se.MultiSortingExtractor(sortings=sorting_list)
sorting = multi_sorting
else:
raise SpikeSortingError(f"None of the sorting outputs could be loaded")

if self.delete_folders:
for out in self.output_folders:
Expand Down
2 changes: 1 addition & 1 deletion spikesorters/sorterlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def run_sorter(sorter_name_or_class, recording, output_folder=None, delete_outpu
verbose=verbose, delete_output_folder=delete_output_folder)
sorter.set_params(**params)
sorter.run(raise_error=raise_error, parallel=parallel, n_jobs=n_jobs, joblib_backend=joblib_backend)
sortingextractor = sorter.get_result()
sortingextractor = sorter.get_result(raise_error=raise_error)

return sortingextractor

Expand Down

0 comments on commit 12e3571

Please sign in to comment.