Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Selecting train IDs in DataCollection and SourceData #559

Merged
merged 6 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/reading_files.rst
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,16 @@ methods offer extra capabilities.
Getting data by train
---------------------

Selecting trains in a run, source or key data returns a new object with a subset
of the data matching the train selection. For example, You can do::

# run data
run = run[:100] # first 100 trains in the run
# source data
source = source[by_id[12345678]] # data for train ID == 1234568
# key data
key = key[np.s_[10:20]] # data for the 10th to the 20th trains

Some kinds of data, e.g. from AGIPD, are too big to load a whole run into
memory at once. In these cases, it's convenient to load one train at a time.

Expand Down
21 changes: 15 additions & 6 deletions extra_data/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@
import numpy as np

from . import locality, voview
from .aliases import AliasIndexer
from .exceptions import (MultiRunError, PropertyNameError, SourceNameError,
TrainIDError)
from .file_access import FileAccess
from .keydata import KeyData
from .read_machinery import (DETECTOR_SOURCE_RE, FilenameInfo, by_id, by_index,
find_proposal, glob_wildcards_re, same_run,
select_train_ids, split_trains)
from .read_machinery import (DETECTOR_SOURCE_RE, by_id, by_index,
find_proposal, glob_wildcards_re, is_int_like,
same_run, select_train_ids)
from .run_files_map import RunFilesMap
from .sourcedata import SourceData
from .utils import available_cpu_cores
from .aliases import AliasIndexer

__all__ = [
'H5File',
Expand Down Expand Up @@ -274,12 +274,21 @@ def _get_source_data(self, source):
return self._sources_data[source]

def __getitem__(self, item):
if isinstance(item, tuple) and len(item) == 2:
if (
isinstance(item, tuple) and
len(item) == 2 and
all(isinstance(e, str) for e in item)
):
return self._get_key_data(*item)
elif isinstance(item, str):
return self._get_source_data(item)
elif (
isinstance(item, (by_id, by_index, list, np.ndarray, slice)) or
is_int_like(item)
):
return self.select_trains(item)

raise TypeError("Expected data[source] or data[source, key]")
raise TypeError("Expected data[source], data[source, key] or data[train_selection]")
tmichela marked this conversation as resolved.
Show resolved Hide resolved

def _ipython_key_completions_(self):
return list(self.all_sources)
Expand Down
16 changes: 11 additions & 5 deletions extra_data/sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import re
from typing import Dict, List, Optional

import numpy as np
import h5py
import numpy as np

from .exceptions import MultiRunError, PropertyNameError, NoDataError
from .exceptions import MultiRunError, NoDataError, PropertyNameError
from .file_access import FileAccess
from .keydata import KeyData
from .read_machinery import (
glob_wildcards_re, same_run, select_train_ids, split_trains, trains_files_index
)
from .read_machinery import (by_id, by_index, glob_wildcards_re, is_int_like,
same_run, select_train_ids, split_trains,
trains_files_index)


class SourceData:
Expand Down Expand Up @@ -67,6 +67,12 @@ def __contains__(self, key):
return res

def __getitem__(self, key):
if (
isinstance(key, (by_id, by_index, list, np.ndarray, slice)) or
is_int_like(key)
):
return self.select_trains(key)

if key not in self:
raise PropertyNameError(key, self.source)
ds0 = self.files[0].file[
Expand Down
10 changes: 10 additions & 0 deletions extra_data/tests/test_reader_mockdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,10 @@ def test_select(mock_fxe_raw_run):
assert sel_by_kd.control_sources == {kd.source}
assert sel_by_kd.keys_for_source(kd.source) == {kd.key}

# disallow mixing source and train ID selection
with pytest.raises(TypeError):
run['SPB_XTD9_XGM/DOOCS/MAIN', 10]


@pytest.mark.parametrize(
'select_str',
Expand Down Expand Up @@ -717,6 +721,12 @@ def test_select_trains(mock_fxe_raw_run):
with pytest.raises(IndexError):
run.select_trains(by_index[[480]])

assert run[10].train_ids == [10010]
assert run[by_id[10000]].train_ids == [10000]
assert run[by_index[479:555]].train_ids == [10479]
with pytest.raises(IndexError):
run[555]
Dismissed Show dismissed Hide dismissed


def test_split_trains(mock_fxe_raw_run):
run = RunDirectory(mock_fxe_raw_run)
Expand Down
13 changes: 13 additions & 0 deletions extra_data/tests/test_sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ def test_select_trains(mock_spb_raw_run):
assert sel.train_ids == []
assert sel.keys() == xgm.keys()

sel = xgm[by_id[10020:10040]]
assert sel.train_ids == list(range(10020, 10040))

sel = xgm[by_index[:10]]
assert sel.train_ids == list(range(10000, 10010))

sel = xgm[10]
assert sel.train_ids == [10010]
tmichela marked this conversation as resolved.
Show resolved Hide resolved

sel = xgm[999:1000]
assert sel.train_ids == []
assert sel.keys() == xgm.keys()


def test_split_trains(mock_spb_raw_run):
run = RunDirectory(mock_spb_raw_run)
Expand Down