Add concept of legacy source names #527

Merged
10 commits merged on Jun 24, 2024
31 changes: 23 additions & 8 deletions extra_data/file_access.py
@@ -146,6 +146,7 @@ def __init__(self, filename, _cache_info=None):
self.train_ids = _cache_info['train_ids']
self.control_sources = _cache_info['control_sources']
self.instrument_sources = _cache_info['instrument_sources']
self.legacy_sources = _cache_info.get('legacy_sources', {})
self.validity_flag = _cache_info.get('flag', None)
else:
try:
@@ -155,7 +156,8 @@

self.train_ids = tid_data[tid_data != 0]

self.control_sources, self.instrument_sources = self._read_data_sources()
self.control_sources, self.instrument_sources, self.legacy_sources \
= self._read_data_sources()

self.validity_flag = None

@@ -294,7 +296,7 @@ def format_version(self):
return self._format_version

def _read_data_sources(self):
control_sources, instrument_sources = set(), set()
control_sources, instrument_sources, legacy_sources = set(), set(), dict()

# The list of data sources moved in file format 1.0
if self.format_version == '0.5':
@@ -307,16 +309,28 @@
except KeyError:
raise FileStructureError(f'{data_sources_path} not found')

for source in data_sources_group[:]:
if not source:
for source_id in data_sources_group[:]:
if not source_id:
continue
source = source.decode()
category, _, h5_source = source.partition('/')
source_id = source_id.decode()
category, _, h5_source = source_id.partition('/')
if category == 'INSTRUMENT':
device, _, chan_grp = h5_source.partition(':')
chan, _, group = chan_grp.partition('/')
source = device + ':' + chan
instrument_sources.add(source)

if source not in instrument_sources:
# The legacy source name is only expected to be used
# for instrument sources (more precisely, XTDF
# sources) for now. For performance reasons, the
# check is therefore only performed here, and only
# once rather than per index group.
item = self.file.get(f'{category}/{source}', getlink=True)

if isinstance(item, h5py.SoftLink):
legacy_sources[source] = item.path[1:].partition('/')[2]

instrument_sources.add(source)
# TODO: Do something with groups?
elif category == 'CONTROL':
control_sources.add(h5_source)
@@ -327,7 +341,8 @@
else:
raise ValueError("Unknown data category %r" % category)

return frozenset(control_sources), frozenset(instrument_sources)
return frozenset(control_sources), frozenset(instrument_sources), \
legacy_sources

def _guess_valid_trains(self):
# File format version 1.0 includes a flag which is 0 if a train ID
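To illustrate the new branch in `_read_data_sources()`: a minimal sketch, with made-up file and source names, of how a legacy source name is stored as an HDF5 soft link and how its target is resolved, mirroring the code above.

```python
# Minimal sketch (invented paths): a legacy source name kept as an HDF5
# soft link, resolved the same way as in _read_data_sources() above.
import h5py

with h5py.File('demo.h5', 'w') as f:
    # Data lives under the canonical (corrected) source name...
    f.create_group('INSTRUMENT/SPB_DET_AGIPD1M-1/CORR/0CH0:output/image')
    # ...while the legacy name is kept as a soft link pointing to it.
    f.require_group('INSTRUMENT/SPB_DET_AGIPD1M-1/DET')
    f['INSTRUMENT/SPB_DET_AGIPD1M-1/DET/0CH0:xtdf'] = h5py.SoftLink(
        '/INSTRUMENT/SPB_DET_AGIPD1M-1/CORR/0CH0:output')

with h5py.File('demo.h5', 'r') as f:
    item = f.get('INSTRUMENT/SPB_DET_AGIPD1M-1/DET/0CH0:xtdf', getlink=True)
    assert isinstance(item, h5py.SoftLink)
    # item.path is the link target; strip the leading '/' and the
    # 'INSTRUMENT/' prefix to recover the canonical source name.
    print(item.path[1:].partition('/')[2])
    # -> SPB_DET_AGIPD1M-1/CORR/0CH0:output
```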
2 changes: 1 addition & 1 deletion extra_data/read_machinery.py
@@ -16,7 +16,7 @@
log = logging.getLogger(__name__)

DETECTOR_NAMES = {'AGIPD', 'DSSC', 'LPD'}
DETECTOR_SOURCE_RE = re.compile(r'(.+)/DET/(\d+)CH')
DETECTOR_SOURCE_RE = re.compile(r'(.+\/(?:DET|CORR))\/(\d+)CH')

DATA_ROOT_DIR = os.environ.get('EXTRA_DATA_DATA_ROOT', '/gpfs/exfel/exp')

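A quick check of the widened pattern (illustrative only): it now matches corrected (`CORR`) module sources as well as raw (`DET`) ones, and the first capture group now includes that section, which is what lets `info()` group modules per detector below.

```python
import re

DETECTOR_SOURCE_RE = re.compile(r'(.+\/(?:DET|CORR))\/(\d+)CH')

for name in ['SPB_DET_AGIPD1M-1/DET/0CH0:xtdf',
             'SPB_DET_AGIPD1M-1/CORR/0CH0:output']:
    m = DETECTOR_SOURCE_RE.match(name)
    print(m.group(1), '| module', m.group(2))
# SPB_DET_AGIPD1M-1/DET | module 0
# SPB_DET_AGIPD1M-1/CORR | module 0
```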
88 changes: 64 additions & 24 deletions extra_data/reader.py
@@ -96,19 +96,22 @@ def __init__(

if sources_data is None:
files_by_sources = defaultdict(list)
legacy_sources = dict()
for f in self.files:
for source in f.control_sources:
files_by_sources[source, 'CONTROL'].append(f)
for source in f.instrument_sources:
files_by_sources[source, 'INSTRUMENT'].append(f)
legacy_sources.update(f.legacy_sources)
sources_data = {
src: SourceData(src,
sel_keys=None,
train_ids=train_ids,
files=files,
section=section,
canonical_name=legacy_sources.get(src, src),
is_single_run=self.is_single_run,
inc_suspect_trains=self.inc_suspect_trains,
inc_suspect_trains=self.inc_suspect_trains
)
for ((src, section), files) in files_by_sources.items()
}
@@ -127,6 +130,10 @@
name for (name, sd) in self._sources_data.items()
if sd.section == 'INSTRUMENT'
})
self.legacy_sources = {
name: sd.canonical_name for (name, sd)
in self._sources_data.items() if sd.is_legacy
}

@staticmethod
def _open_file(path, cache_info=None):
@@ -223,7 +230,8 @@ def all_sources(self):

@property
def detector_sources(self):
return set(filter(DETECTOR_SOURCE_RE.match, self.instrument_sources))
return set(filter(DETECTOR_SOURCE_RE.match, self.instrument_sources)) \
- self.legacy_sources.keys()

def _check_field(self, source, key):
if source not in self.all_sources:
@@ -255,6 +263,14 @@ def _get_source_data(self, source):
if source not in self._sources_data:
raise SourceNameError(source)

sd = self._sources_data[source]

if sd.is_legacy:
warn(f"{source} is a legacy name for {self.legacy_sources[source]}. "
f"Access via this name will be removed for future data.",
DeprecationWarning,
stacklevel=3)

return self._sources_data[source]

def __getitem__(self, item):
@@ -283,7 +299,8 @@ def _check_data_missing(self, tid) -> bool:
if file is None:
return True

for source in self.instrument_sources:
# No need to evaluate this for legacy sources as well.
for source in self.instrument_sources - self.legacy_sources.keys():
file, pos = self._find_data(source, tid)
if file is None:
return True
@@ -1249,30 +1266,30 @@ def info(self, details_for_sources=()):
# Include summary section for multi-module detectors unless
# source details are enabled.

detector_modules = {}
sources_by_detector = {}
for source in self.detector_sources:
name, modno = DETECTOR_SOURCE_RE.match(source).groups((1, 2))
detector_modules[(name, modno)] = source
sources_by_detector.setdefault(name, {})[modno] = source

# A run should only have one detector, but if that changes, don't hide it
detector_name = ','.join(sorted(set(k[0] for k in detector_modules)))
for detector_name in sorted(sources_by_detector.keys()):
detector_modules = sources_by_detector[detector_name]

print("{} XTDF detector modules ({})".format(
len(self.detector_sources), detector_name
))
if len(detector_modules) > 0:
# Show detail on the first module (the others should be similar)
mod_key = sorted(detector_modules)[0]
mod_source = detector_modules[mod_key]
dinfo = self.detector_info(mod_source)
module = ' '.join(mod_key)
dims = ' x '.join(str(d) for d in dinfo['dims'])
print(" e.g. module {} : {} pixels".format(module, dims))
print(" {}".format(mod_source))
print(" {} frames per train, up to {} frames total".format(
dinfo['frames_per_train'], dinfo['total_frames']
print("{} XTDF detector modules of {}/*".format(
len(detector_modules), detector_name
))
print()
if len(detector_modules) > 0:
# Show detail on the first module (the others should be similar)
mod_key = sorted(detector_modules)[0]
mod_source = detector_modules[mod_key]
dinfo = self.detector_info(mod_source)
module = ' '.join(mod_key)
dims = ' x '.join(str(d) for d in dinfo['dims'])
print(" e.g. module {} : {} pixels".format(module, dims))
print(" {}".format(mod_source))
print(" {} frames per train, up to {} frames total".format(
dinfo['frames_per_train'], dinfo['total_frames']
))
print()

# Invert aliases for faster lookup.
src_aliases = defaultdict(set)
@@ -1321,11 +1338,11 @@ def keys_detail(s, keys, prefix=''):

if details_for_sources:
# All instrument sources with details enabled.
displayed_inst_srcs = self.instrument_sources
displayed_inst_srcs = self.instrument_sources - self.legacy_sources.keys()
print(len(displayed_inst_srcs), 'instrument sources:')
else:
# Only non-XTDF instrument sources without details enabled.
displayed_inst_srcs = self.instrument_sources - self.detector_sources
displayed_inst_srcs = self.instrument_sources - self.detector_sources - self.legacy_sources.keys()
print(len(displayed_inst_srcs), 'instrument sources (excluding XTDF detectors):')

for s in sorted(displayed_inst_srcs):
@@ -1377,6 +1394,29 @@

print()

if self.legacy_sources:
# Collect legacy sources matching DETECTOR_SOURCE_RE
# separately for a condensed view.
detector_legacy_sources = defaultdict(set)

print(len(self.legacy_sources), 'legacy source names:')
for s in sorted(self.legacy_sources.keys()):
m = DETECTOR_SOURCE_RE.match(s)

if m is not None:
detector_legacy_sources[m[1]].add(s)
else:
# Only print non-XTDF legacy sources.
print(' -', s, '->', self.legacy_sources[s])

for legacy_det, legacy_sources in detector_legacy_sources.items():
canonical_mod = self.legacy_sources[next(iter(legacy_sources))]
canonical_det = DETECTOR_SOURCE_RE.match(canonical_mod)[1]

print(' -', f'{legacy_det}/*', '->', f'{canonical_det}/*',
f'({len(legacy_sources)})')
print()

def plot_missing_data(self, min_saved_pct=95, expand_instrument=False):
"""Plot sources that have missing data for some trains.

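Taken together, the reader changes should behave roughly as sketched below. The run path is invented; `RunDirectory` and `run[source]` are the existing API, while `legacy_sources`, `canonical_name` and the warning are what this PR adds.

```python
import warnings
from extra_data import RunDirectory

run = RunDirectory('/path/to/proc/r0142')  # invented path

# Mapping of legacy name -> canonical name for sources stored as soft links.
print(run.legacy_sources)

legacy_name = next(iter(run.legacy_sources))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    sd = run[legacy_name]  # access still works via the legacy name...

# ...but emits a DeprecationWarning pointing at the canonical name.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
print(sd.canonical_name)  # == run.legacy_sources[legacy_name]
```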
15 changes: 11 additions & 4 deletions extra_data/run_files_map.py
@@ -142,8 +142,11 @@ def get(self, path):
res = {
'train_ids': np.array(d['train_ids'], dtype=np.uint64),
'control_sources': frozenset(d['control_sources']),
'instrument_sources': frozenset(d['instrument_sources'])
'instrument_sources': frozenset(d['instrument_sources']),
}
# Older cache files don't contain info on legacy sources.
if 'legacy_sources' in d:
res['legacy_sources'] = d['legacy_sources']
# Older cache files don't contain info on 'suspect' trains.
if 'suspect_train_indices' in d:
res['flag'] = flag = np.ones_like(d['train_ids'], dtype=np.bool_)
@@ -155,9 +158,11 @@
def _cache_valid(self, fname):
# The cache is invalid (needs to be written out) if the file is not in
# files_data (which it won't be if the size or mtime don't match - see
# load()), or if suspect_train_indices is missing. This was added after
# we started making cache files, so we want to add it to existing caches.
return 'suspect_train_indices' in self.files_data.get(fname, {})
# load()), or if the later-added suspect_train_indices/legacy_sources
# keys are missing. These may be absent from caches created by older
# versions of EXtra-data.
return not bool({'suspect_train_indices', 'legacy_sources'} \
- self.files_data.get(fname, {}).keys())

def save(self, files):
"""Save the cache if needed
@@ -192,6 +197,8 @@ def save(self, files):
'train_ids': [int(t) for t in file_access.train_ids],
'control_sources': sorted(file_access.control_sources),
'instrument_sources': sorted(file_access.instrument_sources),
'legacy_sources': {k: file_access.legacy_sources[k]
for k in sorted(file_access.legacy_sources)},
'suspect_train_indices': [
int(i) for i in (~file_access.validity_flag).nonzero()[0]
],
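Based on `save()` above, one file entry in the cached run map would now look roughly like this (values invented, showing only the keys that appear in this diff):

```python
# Shape of one cached file entry after this change (values invented).
entry = {
    'train_ids': [10000, 10001, 10002],
    'control_sources': ['SPB_XTD9_XGM/DOOCS/MAIN'],
    'instrument_sources': ['SPB_DET_AGIPD1M-1/CORR/0CH0:output'],
    'legacy_sources': {
        'SPB_DET_AGIPD1M-1/DET/0CH0:xtdf':
            'SPB_DET_AGIPD1M-1/CORR/0CH0:output',
    },
    'suspect_train_indices': [],
}
```

Per `_cache_valid()`, an entry missing either of the two newer keys is treated as stale and rewritten.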
11 changes: 10 additions & 1 deletion extra_data/sourcedata.py
@@ -23,13 +23,14 @@ class SourceData:

def __init__(
self, source, *, sel_keys, train_ids, files, section,
is_single_run, inc_suspect_trains=True
canonical_name, is_single_run, inc_suspect_trains=True,
):
self.source = source
self.sel_keys = sel_keys
self.train_ids = train_ids
self.files: List[FileAccess] = files
self.section = section
self.canonical_name = canonical_name
self.is_single_run = is_single_run
self.inc_suspect_trains = inc_suspect_trains

@@ -47,6 +48,11 @@ def is_instrument(self):
"""Whether this source is an instrument source."""
return self.section == 'INSTRUMENT'

@property
def is_legacy(self):
"""Whether this source is a legacy name for another source."""
return self.canonical_name != self.source

def _has_exact_key(self, key):
if self.sel_keys is not None:
return key in self.sel_keys
@@ -258,6 +264,7 @@ def select_keys(self, keys) -> 'SourceData':
train_ids=self.train_ids,
files=self.files,
section=self.section,
canonical_name=self.canonical_name,
is_single_run=self.is_single_run,
inc_suspect_trains=self.inc_suspect_trains
)
@@ -283,6 +290,7 @@ def _only_tids(self, tids, files=None) -> 'SourceData':
train_ids=tids,
files=files,
section=self.section,
canonical_name=self.canonical_name,
is_single_run=self.is_single_run,
inc_suspect_trains=self.inc_suspect_trains
)
@@ -481,6 +489,7 @@ def union(self, *others) -> 'SourceData':
train_ids=sorted(train_ids),
files=sorted(files, key=lambda f: f.filename),
section=self.section,
canonical_name=self.canonical_name,
is_single_run=same_run(self, *others),
inc_suspect_trains=self.inc_suspect_trains
)
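Since `canonical_name` is forwarded through `select_keys`, `_only_tids` and `union`, the `is_legacy` flag should survive derived selections. A small sketch under the same invented names as the reader.py example above:

```python
import numpy as np
from extra_data import RunDirectory

run = RunDirectory('/path/to/proc/r0142')   # invented path, as above
legacy_name = next(iter(run.legacy_sources))

sd = run[legacy_name]               # warns, as shown earlier
sub = sd.select_trains(np.s_[:10])  # any derived selection

# canonical_name is forwarded, so the legacy flag survives selections.
assert sub.is_legacy
assert sub.canonical_name == sd.canonical_name
```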
8 changes: 8 additions & 0 deletions extra_data/tests/conftest.py
@@ -158,6 +158,14 @@ def mock_spb_raw_and_proc_run():
yield td, raw_run_dir, proc_run_dir


@pytest.fixture(scope='session')
def mock_modern_spb_proc_run(format_version):
with TemporaryDirectory() as td:
make_examples.make_modern_spb_proc_run(
td, format_version=format_version)
yield td


@pytest.fixture(scope='session')
def mock_jungfrau_run():
with TemporaryDirectory() as td:
10 changes: 10 additions & 0 deletions extra_data/tests/make_examples.py
@@ -346,6 +346,16 @@ def make_reduced_spb_run(dir_path, raw=True, rng=None, format_version='0.5'):
format_version=format_version)


def make_modern_spb_proc_run(dir_path, format_version='0.5'):
for modno in range(16):
path = osp.join(dir_path, f'CORR-R0142-AGIPD{modno:0>2}-S00000.h5')
write_file(path, [
AGIPDModule(f'SPB_DET_AGIPD1M-1/CORR/{modno}CH0', raw=False,
frames_per_train=32,
legacy_name=f'SPB_DET_AGIPD1M-1/DET/{modno}CH0')
], ntrains=64, chunksize=32, format_version=format_version)


def make_agipd1m_run(
dir_path,
rep_rate=True,
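A test built on the new fixture could then look something like this (hypothetical test, not part of this diff):

```python
from extra_data import RunDirectory

def test_modern_proc_run_legacy_sources(mock_modern_spb_proc_run):
    run = RunDirectory(mock_modern_spb_proc_run)

    # One legacy (DET) alias per corrected (CORR) module.
    assert len(run.legacy_sources) == 16
    for legacy, canonical in run.legacy_sources.items():
        assert '/DET/' in legacy
        assert '/CORR/' in canonical
```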