diff --git a/docs/reading_files.rst b/docs/reading_files.rst index 31b9f639..bf549cef 100644 --- a/docs/reading_files.rst +++ b/docs/reading_files.rst @@ -268,6 +268,16 @@ methods offer extra capabilities. Getting data by train --------------------- +Selecting trains in a run, source or key data returns a new object with a subset +of the data matching the train selection. For example, you can do:: + + # run data + run = run[:100] # first 100 trains in the run + # source data + source = source[by_id[12345678]] # data for train ID == 12345678 + # key data + key = key[np.s_[10:20]] # data for the 11th to the 20th trains + Some kinds of data, e.g. from AGIPD, are too big to load a whole run into memory at once. In these cases, it's convenient to load one train at a time. diff --git a/extra_data/reader.py b/extra_data/reader.py index f38fc230..e18b128d 100644 --- a/extra_data/reader.py +++ b/extra_data/reader.py @@ -33,17 +33,17 @@ import numpy as np from . import locality, voview +from .aliases import AliasIndexer from .exceptions import (MultiRunError, PropertyNameError, SourceNameError, TrainIDError) from .file_access import FileAccess from .keydata import KeyData -from .read_machinery import (DETECTOR_SOURCE_RE, FilenameInfo, by_id, by_index, - find_proposal, glob_wildcards_re, same_run, - select_train_ids, split_trains) +from .read_machinery import (DETECTOR_SOURCE_RE, by_id, by_index, + find_proposal, glob_wildcards_re, is_int_like, + same_run, select_train_ids) from .run_files_map import RunFilesMap from .sourcedata import SourceData from .utils import available_cpu_cores -from .aliases import AliasIndexer __all__ = [ 'H5File', @@ -274,12 +274,21 @@ def _get_source_data(self, source): return self._sources_data[source] def __getitem__(self, item): - if isinstance(item, tuple) and len(item) == 2: + if ( + isinstance(item, tuple) and + len(item) == 2 and + all(isinstance(e, str) for e in item) + ): return self._get_key_data(*item) elif isinstance(item, str): return 
self._get_source_data(item) + elif ( + isinstance(item, (by_id, by_index, list, np.ndarray, slice)) or + is_int_like(item) + ): + return self.select_trains(item) - raise TypeError("Expected data[source] or data[source, key]") + raise TypeError("Expected data[source], data[source, key] or data[train_selection]") def _ipython_key_completions_(self): return list(self.all_sources) diff --git a/extra_data/sourcedata.py b/extra_data/sourcedata.py index 5897cdc0..215a2142 100644 --- a/extra_data/sourcedata.py +++ b/extra_data/sourcedata.py @@ -2,15 +2,15 @@ import re from typing import Dict, List, Optional -import numpy as np import h5py +import numpy as np -from .exceptions import MultiRunError, PropertyNameError, NoDataError +from .exceptions import MultiRunError, NoDataError, PropertyNameError from .file_access import FileAccess from .keydata import KeyData -from .read_machinery import ( - glob_wildcards_re, same_run, select_train_ids, split_trains, trains_files_index -) +from .read_machinery import (by_id, by_index, glob_wildcards_re, is_int_like, + same_run, select_train_ids, split_trains, + trains_files_index) class SourceData: @@ -67,6 +67,12 @@ def __contains__(self, key): return res def __getitem__(self, key): + if ( + isinstance(key, (by_id, by_index, list, np.ndarray, slice)) or + is_int_like(key) + ): + return self.select_trains(key) + if key not in self: raise PropertyNameError(key, self.source) ds0 = self.files[0].file[ diff --git a/extra_data/tests/test_reader_mockdata.py b/extra_data/tests/test_reader_mockdata.py index b608a6ed..fa536d51 100644 --- a/extra_data/tests/test_reader_mockdata.py +++ b/extra_data/tests/test_reader_mockdata.py @@ -607,6 +607,10 @@ def test_select(mock_fxe_raw_run): assert sel_by_kd.control_sources == {kd.source} assert sel_by_kd.keys_for_source(kd.source) == {kd.key} + # disallow mixing source and train ID selection + with pytest.raises(TypeError): + run['SPB_XTD9_XGM/DOOCS/MAIN', 10] + @pytest.mark.parametrize( 'select_str', @@ 
-717,6 +721,12 @@ def test_select_trains(mock_fxe_raw_run): with pytest.raises(IndexError): run.select_trains(by_index[[480]]) + assert run[10].train_ids == [10010] + assert run[by_id[10000]].train_ids == [10000] + assert run[by_index[479:555]].train_ids == [10479] + with pytest.raises(IndexError): + run[555] + def test_split_trains(mock_fxe_raw_run): run = RunDirectory(mock_fxe_raw_run) diff --git a/extra_data/tests/test_sourcedata.py b/extra_data/tests/test_sourcedata.py index e1183c6e..a237eedc 100644 --- a/extra_data/tests/test_sourcedata.py +++ b/extra_data/tests/test_sourcedata.py @@ -114,6 +114,19 @@ def test_select_trains(mock_spb_raw_run): assert sel.train_ids == [] assert sel.keys() == xgm.keys() + sel = xgm[by_id[10020:10040]] + assert sel.train_ids == list(range(10020, 10040)) + + sel = xgm[by_index[:10]] + assert sel.train_ids == list(range(10000, 10010)) + + sel = xgm[10] + assert sel.train_ids == [10010] + + sel = xgm[999:1000] + assert sel.train_ids == [] + assert sel.keys() == xgm.keys() + def test_split_trains(mock_spb_raw_run): run = RunDirectory(mock_spb_raw_run)