Faster split_trains for long runs #459

Merged: 6 commits, Nov 1, 2023
30 changes: 21 additions & 9 deletions extra_data/keydata.py
@@ -6,7 +6,8 @@
from .exceptions import TrainIDError, NoDataError
from .file_access import FileAccess
from .read_machinery import (
contiguous_regions, DataChunk, select_train_ids, split_trains, roi_shape
contiguous_regions, DataChunk, select_train_ids, split_trains, roi_shape,
trains_files_index,
)

class KeyData:
@@ -185,13 +186,16 @@ def select_trains(self, trains):
def __getitem__(self, item):
return self.select_trains(item)

def _only_tids(self, tids):
def _only_tids(self, tids, files=None):
tids_arr = np.array(tids)
# Keep 1 file, even if 0 trains selected.
files = [
f for f in self.files
if f.has_train_ids(tids_arr, self.inc_suspect_trains)
] or [self.files[0]]
if files is None:
files = [
f for f in self.files
if f.has_train_ids(tids_arr, self.inc_suspect_trains)
]
if not files:
# Keep 1 file, even if 0 trains selected.
files = [self.files[0]]

return KeyData(
self.source,
@@ -232,8 +236,16 @@ def split_trains(self, parts=None, trains_per_part=None):
A maximum number of trains in each part. Parts will often have
fewer trains than this.
"""
for s in split_trains(len(self.train_ids), parts, trains_per_part):
yield self.select_trains(s)
# tids_files points to the file for each train.
# This avoids checking all files for each chunk, which can be slow.
tids_files = trains_files_index(
self.train_ids, self.files, self.inc_suspect_trains
)
for sl in split_trains(len(self.train_ids), parts, trains_per_part):
tids = self.train_ids[sl]
files = set(tids_files[sl]) - {None}
files = sorted(files, key=lambda f: f.filename)
yield self._only_tids(tids, files=files)

def data_counts(self, labelled=True):
"""Get a count of data entries in each train.
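A short usage sketch of the iteration this change speeds up. The proposal/run numbers and the source/key names are placeholders; each part yielded by KeyData.split_trains is a smaller KeyData covering one slice of trains and, with this change, only the files that contain them.

from extra_data import open_run

run = open_run(proposal=700000, run=1)  # placeholder proposal/run numbers
kd = run['SA1_XTD2_XGM/XGM/DOOCS:output', 'data.intensityTD']  # placeholder source/key
for part in kd.split_trains(trains_per_part=500):
    arr = part.ndarray()  # reads only the trains in this part
    print(len(part.train_ids), arr.shape)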
12 changes: 12 additions & 0 deletions extra_data/read_machinery.py
@@ -152,6 +152,18 @@ def split_trains(n_trains, parts=None, trains_per_part=None) -> [slice]:
for i in range(n_parts)
]

def trains_files_index(train_ids, files, inc_suspect_trains=True) -> list:
"""Make a list of which FileAccess contains each train, used in splitting"""
tids_files = [None] * len(train_ids)
tid_to_ix = {t: i for i, t in enumerate(train_ids)}
for file in files:
f_tids = file.train_ids if inc_suspect_trains else file.valid_train_ids
for tid in f_tids:
ix = tid_to_ix.get(tid, None)
if ix is not None:
tids_files[ix] = file
return tids_files

class DataChunk:
"""Reference to a contiguous chunk of data for one or more trains."""
def __init__(self, file, dataset_path, first, train_ids, counts):
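A minimal sketch of what the new helper produces and how split_trains consumes it. FakeFile is a stand-in for FileAccess with only the attributes the index needs; the import assumes the helper lives in extra_data.read_machinery as added above.

from extra_data.read_machinery import trains_files_index

class FakeFile:
    # Stand-in for FileAccess: only filename and train_ids are used here.
    def __init__(self, filename, train_ids):
        self.filename = filename
        self.train_ids = train_ids

f1 = FakeFile('r0001-f1.h5', [1, 2, 3])
f2 = FakeFile('r0001-f2.h5', [4, 5])
index = trains_files_index([1, 2, 3, 4, 5, 6], [f1, f2])
# index == [f1, f1, f1, f2, f2, None]  (train 6 appears in no file)
part = slice(2, 5)  # one chunk produced by split_trains()
files = sorted(set(index[part]) - {None}, key=lambda f: f.filename)
# files == [f1, f2]: only the files needed for trains 3, 4 and 5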
27 changes: 23 additions & 4 deletions extra_data/reader.py
@@ -1013,13 +1013,12 @@ def select(self, seln_or_source_glob, key_glob='*', require_all=False,
else: # require_any
train_ids = np.union1d(train_ids, source_tids)

train_ids = list(train_ids) # Convert back to a list.
sources_data = {
src: srcdata._only_tids(train_ids)
for src, srcdata in sources_data.items()
}

train_ids = list(train_ids) # Convert back to a list.

else:
train_ids = self.train_ids

@@ -1123,8 +1122,28 @@ def split_trains(self, parts=None, trains_per_part=None):
A maximum number of trains in each part. Parts will often have
fewer trains than this.
"""
for s in split_trains(len(self.train_ids), parts, trains_per_part):
yield self.select_trains(s)
for source in self._sources_data.values():
assert source.train_ids == self.train_ids

def dict_zip(iter_d):
while True:
try:
yield {k: next(v) for (k, v) in iter_d.items()}
except StopIteration:
return

for sources_data_part in dict_zip({
n: s.split_trains(parts=parts, trains_per_part=trains_per_part)
for (n, s) in self._sources_data.items()
}):
files = set().union(*[sd.files for sd in sources_data_part.values()])
train_ids = list(sources_data_part.values())[0].train_ids

yield DataCollection(
files, sources_data=sources_data_part, train_ids=train_ids,
aliases=self._aliases, inc_suspect_trains=self.inc_suspect_trains,
is_single_run=self.is_single_run,
)

def _check_source_conflicts(self):
"""Check for data with the same source and train ID in different files.
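The dict_zip helper above simply advances one split_trains generator per source in lockstep, yielding a dict of parts and stopping once any generator is exhausted. A standalone sketch of the same idea:

def dict_zip(iter_d):
    # Yield {key: next value} for each iterator; stop when any of them runs out.
    while True:
        try:
            yield {k: next(v) for (k, v) in iter_d.items()}
        except StopIteration:
            return

parts = dict_zip({'a': iter([1, 2]), 'b': iter(['x', 'y'])})
assert list(parts) == [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'y'}]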
33 changes: 24 additions & 9 deletions extra_data/sourcedata.py
@@ -8,7 +8,9 @@
from .exceptions import MultiRunError, PropertyNameError, NoDataError
from .file_access import FileAccess
from .keydata import KeyData
from .read_machinery import glob_wildcards_re, same_run, select_train_ids, split_trains
from .read_machinery import (
glob_wildcards_re, same_run, select_train_ids, split_trains, trains_files_index
)


class SourceData:
@@ -262,16 +264,21 @@ def select_trains(self, trains) -> 'SourceData':
"""
return self._only_tids(select_train_ids(self.train_ids, trains))

def _only_tids(self, tids) -> 'SourceData':
def _only_tids(self, tids, files=None) -> 'SourceData':
if files is None:
files = [
f for f in self.files
if f.has_train_ids(tids, self.inc_suspect_trains)
]
if not files:
# Keep 1 file, even if 0 trains selected, to get keys, dtypes, etc.
files = [self.files[0]]

return SourceData(
self.source,
sel_keys=self.sel_keys,
train_ids=tids,
# Keep 1 file, even if 0 trains selected, to get keys, dtypes, etc.
files=[
f for f in self.files
if f.has_train_ids(tids, self.inc_suspect_trains)
] or [self.files[0]],
files=files,
section=self.section,
is_single_run=self.is_single_run,
inc_suspect_trains=self.inc_suspect_trains
@@ -309,8 +316,16 @@ def split_trains(self, parts=None, trains_per_part=None):
A maximum number of trains in each part. Parts will often have
fewer trains than this.
"""
for s in split_trains(len(self.train_ids), parts, trains_per_part):
yield self.select_trains(s)
# tids_files points to the file for each train.
# This avoids checking all files for each chunk, which can be slow.
tids_files = trains_files_index(
self.train_ids, self.files, self.inc_suspect_trains
)
for sl in split_trains(len(self.train_ids), parts, trains_per_part):
tids = self.train_ids[sl]
files = set(tids_files[sl]) - {None}
files = sorted(files, key=lambda f: f.filename)
yield self._only_tids(tids, files=files)

def data_counts(self, labelled=True, index_group=None):
"""Get a count of data entries in each train.
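A usage sketch at the SourceData level, mirroring the test below: with the per-train file index, each part only keeps references to the files that actually contain its trains. The proposal/run numbers and the source name are placeholders.

from extra_data import open_run

run = open_run(proposal=700000, run=1)   # placeholder proposal/run numbers
xgm = run['SPB_XTD9_XGM/DOOCS/MAIN']     # placeholder source name, gives a SourceData
for part in xgm.split_trains(parts=3):
    print(len(part.train_ids), [f.filename for f in part.files])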
2 changes: 2 additions & 0 deletions extra_data/tests/test_reader_mockdata.py
@@ -702,6 +702,8 @@ def test_split_trains(mock_fxe_raw_run):
chunks = list(run.split_trains(3))
assert len(chunks) == 3
assert {len(c.train_ids) for c in chunks} == {160}
arr = chunks[0]['FXE_XAD_GEC/CAM/CAMERA:daqOutput', 'data.image.dims'].ndarray()
assert arr.shape == (160, 2)

chunks = list(run.split_trains(4, trains_per_part=100))
assert len(chunks) == 5
2 changes: 2 additions & 0 deletions extra_data/tests/test_sourcedata.py
@@ -111,6 +111,8 @@ def test_split_trains(mock_spb_raw_run):
chunks = list(xgm.split_trains(3))
assert len(chunks) == 3
assert {len(c.train_ids) for c in chunks} == {21, 22}
# The middle chunk spans across 2 files
assert [len(c.files) for c in chunks] == [1, 2, 1]

chunks = list(xgm.split_trains(3, trains_per_part=20))
assert len(chunks) == 4