
Commit

Merge pull request #431 from OpenCOMPES/energy_calibration_performance_fix

faster version of per_file channels
rettigl authored Jun 22, 2024
2 parents ffe2013 + 6854d39 commit 0004113
Showing 2 changed files with 94 additions and 101 deletions.
51 changes: 9 additions & 42 deletions sed/core/dfops.py
@@ -390,53 +390,20 @@ def offset_by_other_columns(
"Please open a request on GitHub if this feature is required.",
)

# calculate the mean of the columns to reduce
means = {
col: dask.delayed(df[col].mean())
for col, red, pm in zip(offset_columns, reductions, preserve_mean)
if red or pm
}

# define the functions to apply the offsets
def shift_by_mean(x, cols, signs, means, flip_signs=False):
"""Shift the target column by the mean of the offset columns."""
for col in cols:
s = -signs[col] if flip_signs else signs[col]
x[target_column] = x[target_column] + s * means[col]
return x[target_column]

def shift_by_row(x, cols, signs):
"""Apply the offsets to the target column."""
for col in cols:
x[target_column] = x[target_column] + signs[col] * x[col]
return x[target_column]

# apply offset from the reduced columns
df[target_column] = df.map_partitions(
shift_by_mean,
cols=[col for col, red in zip(offset_columns, reductions) if red],
signs=signs_dict,
means=means,
meta=df[target_column].dtype,
)
for col, red in zip(offset_columns, reductions):
if red == "mean":
df[target_column] = df[target_column] + signs_dict[col] * df[col].mean()

# apply offset from the offset columns
df[target_column] = df.map_partitions(
shift_by_row,
cols=[col for col, red in zip(offset_columns, reductions) if not red],
signs=signs_dict,
meta=df[target_column].dtype,
)
for col, red in zip(offset_columns, reductions):
if not red:
df[target_column] = df[target_column] + signs_dict[col] * df[col]

# compensate shift from the preserved mean columns
if any(preserve_mean):
df[target_column] = df.map_partitions(
shift_by_mean,
cols=[col for col, pmean in zip(offset_columns, preserve_mean) if pmean],
signs=signs_dict,
means=means,
flip_signs=True,
meta=df[target_column].dtype,
)
for col, pmean in zip(offset_columns, preserve_mean):
if pmean:
df[target_column] = df[target_column] - signs_dict[col] * df[col].mean()

return df
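For reference, the simplification above leans on dask.dataframe supporting lazy column arithmetic directly, so the explicit `map_partitions` calls and pre-computed delayed means are no longer needed. A minimal, self-contained sketch of that pattern, with hypothetical column names and values not taken from sed:

```python
# Sketch of lazy column arithmetic in dask.dataframe, the pattern the rewritten
# offset_by_other_columns relies on instead of map_partitions + delayed means.
# Column names and values are hypothetical.
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"energy": [1.0, 2.0, 3.0, 4.0], "bias": [0.5, 0.5, 1.5, 1.5]})
df = dd.from_pandas(pdf, npartitions=2)

# offset by the mean of another column: df["bias"].mean() is a lazy dask scalar,
# so this only extends the task graph
df["energy"] = df["energy"] - df["bias"].mean()

# row-wise offset by another column, likewise lazy
df["energy"] = df["energy"] + df["bias"]

print(df.compute())
```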
144 changes: 85 additions & 59 deletions sed/loader/mpes/loader.py
@@ -67,30 +67,26 @@ def hdf5_to_dataframe(
seach_pattern="Stream",
)

channel_list = []
electron_channels = []
column_names = []

for name, channel in channels.items():
if (
channel["format"] == "per_electron"
and channel["dataset_key"] in test_proc
or channel["format"] == "per_file"
and channel["dataset_key"] in test_proc.attrs
):
channel_list.append(channel)
column_names.append(name)
else:
print(
f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
"Skipping the channel.",
)
if channel["format"] == "per_electron":
if channel["dataset_key"] in test_proc:
electron_channels.append(channel)
column_names.append(name)
else:
print(
f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
"Skipping the channel.",
)

if time_stamps:
column_names.append(time_stamp_alias)

test_array = hdf5_to_array(
h5file=test_proc,
channels=channel_list,
channels=electron_channels,
time_stamps=time_stamps,
ms_markers_key=ms_markers_key,
first_event_time_stamp_key=first_event_time_stamp_key,
@@ -101,7 +97,7 @@ def hdf5_to_dataframe(
da.from_delayed(
dask.delayed(hdf5_to_array)(
h5file=h5py.File(f),
channels=channel_list,
channels=electron_channels,
time_stamps=time_stamps,
ms_markers_key=ms_markers_key,
first_event_time_stamp_key=first_event_time_stamp_key,
@@ -113,7 +109,25 @@
]
array_stack = da.concatenate(arrays, axis=1).T

return ddf.from_dask_array(array_stack, columns=column_names)
dataframe = ddf.from_dask_array(array_stack, columns=column_names)

for name, channel in channels.items():
if channel["format"] == "per_file":
if channel["dataset_key"] in test_proc.attrs:
values = [float(get_attribute(h5py.File(f), channel["dataset_key"])) for f in files]
delayeds = [
add_value(partition, name, value)
for partition, value in zip(dataframe.partitions, values)
]
dataframe = ddf.from_delayed(delayeds)

else:
print(
f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
"Skipping the channel.",
)

return dataframe
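For context, the per_file values gathered above are resolved with get_attribute and checked against test_proc.attrs, i.e. they appear to come from file-level HDF5 attributes rather than datasets, read once per file. A small, self-contained sketch of that kind of lookup; the file name and attribute key below are invented for illustration:

```python
# Hedged sketch: resolving a per_file channel as one scalar HDF5 attribute per file.
# File name and attribute key are hypothetical.
import h5py
import numpy as np

# write a toy file with an electron stream and one file-level attribute
with h5py.File("toy_run.h5", "w") as f:
    f.create_dataset("Stream_0", data=np.arange(10))
    f.attrs["sample_bias"] = 17.0

# read it back as a per_file channel would be: a single float per file
with h5py.File("toy_run.h5", "r") as f:
    value = float(f.attrs["sample_bias"])

print(value)  # 17.0
```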


def hdf5_to_timed_dataframe(
@@ -156,30 +170,26 @@ def hdf5_to_timed_dataframe(
seach_pattern="Stream",
)

channel_list = []
electron_channels = []
column_names = []

for name, channel in channels.items():
if (
channel["format"] == "per_electron"
and channel["dataset_key"] in test_proc
or channel["format"] == "per_file"
and channel["dataset_key"] in test_proc.attrs
):
channel_list.append(channel)
column_names.append(name)
else:
print(
f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
"Skipping the channel.",
)
if channel["format"] == "per_electron":
if channel["dataset_key"] in test_proc:
electron_channels.append(channel)
column_names.append(name)
else:
print(
f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
"Skipping the channel.",
)

if time_stamps:
column_names.append(time_stamp_alias)

test_array = hdf5_to_timed_array(
h5file=test_proc,
channels=channel_list,
channels=electron_channels,
time_stamps=time_stamps,
ms_markers_key=ms_markers_key,
first_event_time_stamp_key=first_event_time_stamp_key,
@@ -190,7 +200,7 @@
da.from_delayed(
dask.delayed(hdf5_to_timed_array)(
h5file=h5py.File(f),
channels=channel_list,
channels=electron_channels,
time_stamps=time_stamps,
ms_markers_key=ms_markers_key,
first_event_time_stamp_key=first_event_time_stamp_key,
@@ -202,7 +212,41 @@
]
array_stack = da.concatenate(arrays, axis=1).T

return ddf.from_dask_array(array_stack, columns=column_names)
dataframe = ddf.from_dask_array(array_stack, columns=column_names)

for name, channel in channels.items():
if channel["format"] == "per_file":
if channel["dataset_key"] in test_proc.attrs:
values = [float(get_attribute(h5py.File(f), channel["dataset_key"])) for f in files]
delayeds = [
add_value(partition, name, value)
for partition, value in zip(dataframe.partitions, values)
]
dataframe = ddf.from_delayed(delayeds)

else:
print(
f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
"Skipping the channel.",
)

return dataframe


@dask.delayed
def add_value(partition: ddf.DataFrame, name: str, value: float) -> ddf.DataFrame:
"""Dask delayed helper function to add a value to each dataframe partition
Args:
partition (ddf.DataFrame): Dask dataframe partition
name (str): Name of the column to add
value (float): value to add to this partition
Returns:
ddf.DataFrame: Dataframe partition with added column
"""
partition[name] = value
return partition
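
A short, self-contained usage sketch of this helper, assuming one dataframe partition per source file; the values and the column name are hypothetical:

```python
# Hedged usage sketch for the delayed add_value helper: attach one per-file scalar
# as a constant column to each partition, then rebuild the dask dataframe.
# Values and the "sampleBias" column name are made up for illustration.
import dask
import dask.dataframe as ddf
import pandas as pd

@dask.delayed
def add_value(partition, name, value):
    partition[name] = value
    return partition

# stand-in for the per-electron dataframe, one partition per source file
dataframe = ddf.from_pandas(pd.DataFrame({"X": range(6)}), npartitions=2)
values = [17.0, 18.5]  # e.g. one attribute value per file

delayeds = [
    add_value(part, "sampleBias", val)
    for part, val in zip(dataframe.partitions, values)
]
dataframe = ddf.from_delayed(delayeds)
print(dataframe.compute())
```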


def get_datasets_and_aliases(
Expand Down Expand Up @@ -256,7 +300,7 @@ def hdf5_to_array(
Args:
h5file (h5py.File):
hdf5 file handle to read from
electron_channels (Sequence[Dict[str, any]]):
channels (Sequence[Dict[str, any]]):
channel dicts containing group names and types to read.
time_stamps (bool, optional):
Option to calculate time stamps. Defaults to False.
@@ -270,40 +314,25 @@
"""

# Delayed array for loading an HDF5 file of reasonable size (e.g. < 1GB)

# determine group length from per_electron column:
nelectrons = 0
for channel in channels:
if channel["format"] == "per_electron":
nelectrons = len(h5file[channel["dataset_key"]])
break
if nelectrons == 0:
raise ValueError("No 'per_electron' columns defined, or no hits found in file.")

# Read out groups:
data_list = []
for channel in channels:
if channel["format"] == "per_electron":
g_dataset = np.asarray(h5file[channel["dataset_key"]])
elif channel["format"] == "per_file":
value = float(get_attribute(h5file, channel["dataset_key"]))
g_dataset = np.asarray([value] * nelectrons)
else:
raise ValueError(
f"Invalid 'format':{channel['format']} for channel {channel['dataset_key']}.",
)
if "data_type" in channel.keys():
g_dataset = g_dataset.astype(channel["data_type"])
if "dtype" in channel.keys():
g_dataset = g_dataset.astype(channel["dtype"])
else:
g_dataset = g_dataset.astype("float32")
if len(g_dataset) != nelectrons:
raise ValueError(f"Inconsistent entries found for channel {channel['dataset_key']}.")
data_list.append(g_dataset)

# calculate time stamps
if time_stamps:
# create target array for time stamps
time_stamp_data = np.zeros(nelectrons)
time_stamp_data = np.zeros(len(data_list[0]))
# the ms marker contains a list of events that occurred at full ms intervals.
# It's monotonically increasing, and can contain duplicates
ms_marker = np.asarray(h5file[ms_markers_key])
@@ -357,7 +386,7 @@ def hdf5_to_timed_array(
Args:
h5file (h5py.File):
hdf5 file handle to read from
electron_channels (Sequence[Dict[str, any]]):
channels (Sequence[Dict[str, any]]):
channel dicts containing group names and types to read.
time_stamps (bool, optional):
Option to calculate time stamps. Defaults to False.
@@ -382,15 +411,12 @@
g_dataset = np.asarray(h5file[channel["dataset_key"]])
for i, point in enumerate(ms_marker):
timed_dataset[i] = g_dataset[int(point) - 1]
elif channel["format"] == "per_file":
value = float(get_attribute(h5file, channel["dataset_key"]))
timed_dataset[:] = value
else:
raise ValueError(
f"Invalid 'format':{channel['format']} for channel {channel['dataset_key']}.",
)
if "data_type" in channel.keys():
timed_dataset = timed_dataset.astype(channel["data_type"])
if "dtype" in channel.keys():
timed_dataset = timed_dataset.astype(channel["dtype"])
else:
timed_dataset = timed_dataset.astype("float32")

