Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Selecting train IDs in DataCollection and SourceData #559

Merged
merged 6 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions extra_data/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@
import numpy as np

from . import locality, voview
from .aliases import AliasIndexer
from .exceptions import (MultiRunError, PropertyNameError, SourceNameError,
TrainIDError)
from .file_access import FileAccess
from .keydata import KeyData
from .read_machinery import (DETECTOR_SOURCE_RE, FilenameInfo, by_id, by_index,
find_proposal, glob_wildcards_re, same_run,
select_train_ids, split_trains)
from .read_machinery import (DETECTOR_SOURCE_RE, by_id, by_index,
find_proposal, glob_wildcards_re, is_int_like,
same_run, select_train_ids)
from .run_files_map import RunFilesMap
from .sourcedata import SourceData
from .utils import available_cpu_cores
from .aliases import AliasIndexer

__all__ = [
'H5File',
Expand Down Expand Up @@ -278,8 +278,13 @@ def __getitem__(self, item):
return self._get_key_data(*item)
elif isinstance(item, str):
return self._get_source_data(item)
elif (
isinstance(item, (by_id, by_index, list, np.ndarray, slice)) or
is_int_like(item)
):
return self.select_trains(item)

raise TypeError("Expected data[source] or data[source, key]")
raise TypeError("Expected data[source], data[source, key] or data[train_selection]")
tmichela marked this conversation as resolved.
Show resolved Hide resolved

def _ipython_key_completions_(self):
return list(self.all_sources)
Expand Down
16 changes: 11 additions & 5 deletions extra_data/sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import re
from typing import Dict, List, Optional

import numpy as np
import h5py
import numpy as np

from .exceptions import MultiRunError, PropertyNameError, NoDataError
from .exceptions import MultiRunError, NoDataError, PropertyNameError
from .file_access import FileAccess
from .keydata import KeyData
from .read_machinery import (
glob_wildcards_re, same_run, select_train_ids, split_trains, trains_files_index
)
from .read_machinery import (by_id, by_index, glob_wildcards_re, is_int_like,
same_run, select_train_ids, split_trains,
trains_files_index)


class SourceData:
Expand Down Expand Up @@ -67,6 +67,12 @@ def __contains__(self, key):
return res

def __getitem__(self, key):
if (
isinstance(key, (by_id, by_index, list, np.ndarray, slice)) or
is_int_like(key)
):
return self.select_trains(key)

if key not in self:
raise PropertyNameError(key, self.source)
ds0 = self.files[0].file[
Expand Down
6 changes: 6 additions & 0 deletions extra_data/tests/test_reader_mockdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,12 @@
with pytest.raises(IndexError):
run.select_trains(by_index[[480]])

assert run[10].train_ids == [10010]
assert run[by_id[10000]].train_ids == [10000]
assert run[by_index[479:555]].train_ids == [10479]
with pytest.raises(IndexError):
run[555]
Dismissed Show dismissed Hide dismissed


def test_split_trains(mock_fxe_raw_run):
run = RunDirectory(mock_fxe_raw_run)
Expand Down
13 changes: 13 additions & 0 deletions extra_data/tests/test_sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ def test_select_trains(mock_spb_raw_run):
assert sel.train_ids == []
assert sel.keys() == xgm.keys()

sel = xgm[by_id[10020:10040]]
assert sel.train_ids == list(range(10020, 10040))

sel = xgm[by_index[:10]]
assert sel.train_ids == list(range(10000, 10010))

sel = xgm[10]
assert sel.train_ids == [10010]
tmichela marked this conversation as resolved.
Show resolved Hide resolved

sel = xgm[999:1000]
assert sel.train_ids == []
assert sel.keys() == xgm.keys()


def test_split_trains(mock_spb_raw_run):
run = RunDirectory(mock_spb_raw_run)
Expand Down