Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neuromodulator #869

Merged
merged 7 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ibllib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import warnings

__version__ = '2.39.1'
__version__ = '2.39.2'
warnings.filterwarnings('always', category=DeprecationWarning, module='ibllib')

# if this becomes a full-blown library we should let the logging configuration to the discretion of the dev
Expand Down
30 changes: 16 additions & 14 deletions ibllib/io/extractors/bpod_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
This module will extract the Bpod trials and wheel data based on the task protocol,
i.e. habituation, training or biased.
"""
import logging
import importlib

from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor
from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor, BaseExtractor
from ibllib.io.extractors.habituation_trials import HabituationTrials
from ibllib.io.extractors.training_trials import TrainingTrials
from ibllib.io.extractors.biased_trials import BiasedTrials, EphysTrials
from ibllib.io.extractors.base import BaseBpodTrialsExtractor

_logger = logging.getLogger(__name__)


def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavior_data') -> BaseBpodTrialsExtractor:
"""
Expand All @@ -39,20 +36,25 @@ def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavio
'BiasedTrials': BiasedTrials,
'EphysTrials': EphysTrials
}

if protocol:
class_name = protocol2extractor(protocol)
extractor_class_name = protocol2extractor(protocol)
else:
class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
if class_name in builtins:
return builtins[class_name](session_path)
extractor_class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
if extractor_class_name in builtins:
return builtins[extractor_class_name](session_path)

# look if there are custom extractor types in the personal projects repo
if not class_name.startswith('projects.'):
class_name = 'projects.' + class_name
module, class_name = class_name.rsplit('.', 1)
if not extractor_class_name.startswith('projects.'):
extractor_class_name = 'projects.' + extractor_class_name
module, extractor_class_name = extractor_class_name.rsplit('.', 1)
mdl = importlib.import_module(module)
extractor_class = getattr(mdl, class_name, None)
extractor_class = getattr(mdl, extractor_class_name, None)
if extractor_class:
return extractor_class(session_path)
my_extractor = extractor_class(session_path)
if not isinstance(my_extractor, BaseExtractor):
raise ValueError(
f"{my_extractor} should be an Extractor class inheriting from ibllib.io.extractors.base.BaseExtractor")
return my_extractor
else:
raise ValueError(f'extractor {class_name} not found')
raise ValueError(f'extractor {extractor_class_name} not found')
35 changes: 32 additions & 3 deletions ibllib/tests/extractors/test_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import shutil
import tempfile
import unittest
import unittest.mock
from unittest.mock import patch, Mock, MagicMock
from pathlib import Path

import numpy as np
import pandas as pd

import one.alf.io as alfio
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
from ibllib.io.extractors import training_trials, biased_trials, camera
from ibllib.io import raw_data_loaders as raw
from ibllib.io.extractors.base import BaseExtractor
Expand Down Expand Up @@ -530,13 +531,13 @@ def test_size_outputs(self):
'peakVelocity_times': np.array([1, 1])}
function_name = 'ibllib.io.extractors.training_wheel.extract_wheel_moves'
# Training
with unittest.mock.patch(function_name, return_value=mock_data):
with patch(function_name, return_value=mock_data):
task, = get_trials_tasks(self.training_lt5['path'])
trials, _ = task.extract_behaviour(save=True)
trials = alfio.load_object(self.training_lt5['path'] / 'alf', object='trials')
self.assertTrue(alfio.check_dimensions(trials) == 0)
# Biased
with unittest.mock.patch(function_name, return_value=mock_data):
with patch(function_name, return_value=mock_data):
task, = get_trials_tasks(self.biased_lt5['path'])
trials, _ = task.extract_behaviour(save=True)
trials = alfio.load_object(self.biased_lt5['path'] / 'alf', object='trials')
Expand Down Expand Up @@ -753,5 +754,33 @@ def test_attribute_times(self, display=False):
camera.attribute_times(tsa, tsb, injective=False, take='closest')


class TestGetBpodExtractor(unittest.TestCase):

def test_get_bpod_extractor(self):
# un-existing extractor should raise a value error
with self.assertRaises(ValueError):
get_bpod_extractor('', protocol='sdf', task_collection='raw_behavior_data')
# in this case this returns an ibllib.io.extractors.training_trials.TrainingTrials instance
extractor = get_bpod_extractor(
'', protocol='_trainingChoiceWorld',
task_collection='raw_behavior_data'
)
self.assertTrue(isinstance(extractor, BaseExtractor))

def test_get_bpod_custom_extractor(self):
# here we'll mock a custom module with a custom extractor
DummyModule = MagicMock()
DummyExtractor = Mock(spec_set=BaseExtractor)
DummyModule.toto.return_value = DummyExtractor
base_module = 'ibllib.io.extractors.bpod_trials'
with patch(f'{base_module}.get_bpod_extractor_class', return_value='toto'), \
patch(f'{base_module}.importlib.import_module', return_value=DummyModule) as import_mock:
self.assertIs(get_bpod_extractor(''), DummyExtractor)
import_mock.assert_called_with('projects')
# Check raises when imported class not an extractor
DummyModule.toto.return_value = MagicMock(spec=dict)
self.assertRaisesRegex(ValueError, 'should be an Extractor class', get_bpod_extractor, '')


if __name__ == '__main__':
unittest.main(exit=False, verbosity=2)
4 changes: 4 additions & 0 deletions release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

#### 2.39.1
- Bugfix: brainbox.metrics.single_unit.quick_unit_metrics fix for indexing of n_spike_below2
-
#### 2.39.2
- Bugfix: routing of protocol to extractor through the project repository checks that the
target is indeed an extractor class.

## Release Note 2.38.0

Expand Down
Loading