Skip to content

Commit

Permalink
Selecting train IDs in DataCollection and SourceData (#559)
Browse files Browse the repository at this point in the history
accept train id slicing with the getitem syntax for datacollections and sourcedata objects
  • Loading branch information
tmichela authored Oct 23, 2024
1 parent 6fec557 commit fd2370c
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 11 deletions.
10 changes: 10 additions & 0 deletions docs/reading_files.rst
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,16 @@ methods offer extra capabilities.
Getting data by train
---------------------

Selecting trains in a run, source or key data returns a new object with a subset
of the data matching the train selection. For example, You can do::

# run data
run = run[:100] # first 100 trains in the run
# source data
source = source[by_id[12345678]] # data for train ID == 1234568
# key data
key = key[np.s_[10:20]] # data for the 10th to the 20th trains

Some kinds of data, e.g. from AGIPD, are too big to load a whole run into
memory at once. In these cases, it's convenient to load one train at a time.

Expand Down
21 changes: 15 additions & 6 deletions extra_data/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@
import numpy as np

from . import locality, voview
from .aliases import AliasIndexer
from .exceptions import (MultiRunError, PropertyNameError, SourceNameError,
TrainIDError)
from .file_access import FileAccess
from .keydata import KeyData
from .read_machinery import (DETECTOR_SOURCE_RE, FilenameInfo, by_id, by_index,
find_proposal, glob_wildcards_re, same_run,
select_train_ids, split_trains)
from .read_machinery import (DETECTOR_SOURCE_RE, by_id, by_index,
find_proposal, glob_wildcards_re, is_int_like,
same_run, select_train_ids)
from .run_files_map import RunFilesMap
from .sourcedata import SourceData
from .utils import available_cpu_cores
from .aliases import AliasIndexer

__all__ = [
'H5File',
Expand Down Expand Up @@ -274,12 +274,21 @@ def _get_source_data(self, source):
return self._sources_data[source]

def __getitem__(self, item):
if isinstance(item, tuple) and len(item) == 2:
if (
isinstance(item, tuple) and
len(item) == 2 and
all(isinstance(e, str) for e in item)
):
return self._get_key_data(*item)
elif isinstance(item, str):
return self._get_source_data(item)
elif (
isinstance(item, (by_id, by_index, list, np.ndarray, slice)) or
is_int_like(item)
):
return self.select_trains(item)

raise TypeError("Expected data[source] or data[source, key]")
raise TypeError("Expected data[source], data[source, key] or data[train_selection]")

def _ipython_key_completions_(self):
return list(self.all_sources)
Expand Down
16 changes: 11 additions & 5 deletions extra_data/sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import re
from typing import Dict, List, Optional

import numpy as np
import h5py
import numpy as np

from .exceptions import MultiRunError, PropertyNameError, NoDataError
from .exceptions import MultiRunError, NoDataError, PropertyNameError
from .file_access import FileAccess
from .keydata import KeyData
from .read_machinery import (
glob_wildcards_re, same_run, select_train_ids, split_trains, trains_files_index
)
from .read_machinery import (by_id, by_index, glob_wildcards_re, is_int_like,
same_run, select_train_ids, split_trains,
trains_files_index)


class SourceData:
Expand Down Expand Up @@ -67,6 +67,12 @@ def __contains__(self, key):
return res

def __getitem__(self, key):
if (
isinstance(key, (by_id, by_index, list, np.ndarray, slice)) or
is_int_like(key)
):
return self.select_trains(key)

if key not in self:
raise PropertyNameError(key, self.source)
ds0 = self.files[0].file[
Expand Down
10 changes: 10 additions & 0 deletions extra_data/tests/test_reader_mockdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ def test_select(mock_fxe_raw_run):
assert sel_by_kd.control_sources == {kd.source}
assert sel_by_kd.keys_for_source(kd.source) == {kd.key}

# disallow mixing source and train ID selection
with pytest.raises(TypeError):
run['SPB_XTD9_XGM/DOOCS/MAIN', 10]


@pytest.mark.parametrize(
'select_str',
Expand Down Expand Up @@ -721,6 +725,12 @@ def test_select_trains(mock_fxe_raw_run):
with pytest.raises(IndexError):
run.select_trains(by_index[[480]])

assert run[10].train_ids == [10010]
assert run[by_id[10000]].train_ids == [10000]
assert run[by_index[479:555]].train_ids == [10479]
with pytest.raises(IndexError):
run[555]


def test_split_trains(mock_fxe_raw_run):
run = RunDirectory(mock_fxe_raw_run)
Expand Down
13 changes: 13 additions & 0 deletions extra_data/tests/test_sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ def test_select_trains(mock_spb_raw_run):
assert sel.train_ids == []
assert sel.keys() == xgm.keys()

sel = xgm[by_id[10020:10040]]
assert sel.train_ids == list(range(10020, 10040))

sel = xgm[by_index[:10]]
assert sel.train_ids == list(range(10000, 10010))

sel = xgm[10]
assert sel.train_ids == [10010]

sel = xgm[999:1000]
assert sel.train_ids == []
assert sel.keys() == xgm.keys()


def test_split_trains(mock_spb_raw_run):
run = RunDirectory(mock_spb_raw_run)
Expand Down

0 comments on commit fd2370c

Please sign in to comment.