Commit d3076e6: fix merge issues
rettigl committed Aug 12, 2023
1 parent c7655be
Showing 3 changed files with 34 additions and 59 deletions.
sed/core/processor.py (71 changes: 27 additions & 44 deletions)
@@ -298,15 +298,15 @@ def load(
             # In that case, we copy the whole provided base folder tree, and pass the copied
             # version to the loader as base folder to look for the runs.
             if folder is not None:
-                dataframe, metadata = self.loader.read_dataframe(
+                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                     folders=cast(str, self.cpy(folder)),
                     runs=runs,
                     metadata=metadata,
                     collect_metadata=collect_metadata,
                     **kwds,
                 )
             else:
-                dataframe, metadata = self.loader.read_dataframe(
+                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                     runs=runs,
                     metadata=metadata,
                     collect_metadata=collect_metadata,
@@ -320,25 +320,20 @@ def load(
                 collect_metadata=collect_metadata,
                 **kwds,
             )
-            self._dataframe = dataframe
-            self._timed_dataframe = timed_dataframe
-            self._files = self.loader.files
         elif files is not None:
             dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                 files=cast(List[str], self.cpy(files)),
                 metadata=metadata,
                 collect_metadata=collect_metadata,
                 **kwds,
             )
-            self._dataframe = dataframe
-            self._timed_dataframe = timed_dataframe
-            self._files = self.loader.files
         else:
             raise ValueError(
                 "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
             )
 
         self._dataframe = dataframe
+        self._timed_dataframe = timed_dataframe
         self._files = self.loader.files
 
         for key in metadata:
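Note on the hunks above: after this commit, every loader's `read_dataframe` returns a `(dataframe, timed_dataframe, metadata)` triple, and `load` stores all three attributes once, after the branch, instead of duplicating the assignments per branch. A minimal sketch of the resulting caller-side pattern; the loader name, empty config, and file path are illustrative placeholders, not taken from this diff:

```python
from sed.loader.loader_interface import get_loader

# Hypothetical loader and input file, for illustration only.
loader = get_loader("generic", config={})

# read_dataframe now always yields a three-tuple; callers that do not
# need the timed dataframe can discard it with "_".
dataframe, timed_dataframe, metadata = loader.read_dataframe(
    files=["scan_001.h5"],
    collect_metadata=False,
)
```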
@@ -1431,11 +1426,9 @@ def compute(
                 print(
                     f"Calculate normalization histogram for axis '{axis}'...",
                 )
-                self._normalization_histogram = (
-                    self.get_normalization_histogram(
-                        axis=axis,
-                        df_partitions=df_partitions,
-                    )
+                self._normalization_histogram = self.get_normalization_histogram(
+                    axis=axis,
+                    df_partitions=df_partitions,
                 )
                 # if the axes are named correctly, xarray figures out the normalization correctly
                 self._normalized = self._binned / self._normalization_histogram
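The inline comment relies on xarray's name-based broadcasting: dividing an N-dimensional DataArray by a lower-dimensional one aligns on matching dimension names, so only the named axis is normalized. A self-contained toy illustration with synthetic data (not taken from the repository):

```python
import numpy as np
import xarray as xr

# 2-D "binned" counts over axes named "delay" and "energy".
binned = xr.DataArray(
    np.full((3, 4), 8.0),
    dims=("delay", "energy"),
    coords={"delay": [0, 1, 2], "energy": [10, 20, 30, 40]},
)

# 1-D normalization histogram over the "delay" axis only.
histogram = xr.DataArray(
    [1.0, 2.0, 4.0],
    dims=("delay",),
    coords={"delay": [0, 1, 2]},
)

# Matching dimension names let xarray broadcast the division along
# "delay" while leaving "energy" untouched.
normalized = binned / histogram
print(normalized.sel(delay=2).values)  # [2. 2. 2. 2.]
```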
@@ -1457,9 +1450,7 @@ def compute(
             )
 
             self._normalized.attrs["units"] = "counts/second"
-            self._normalized.attrs[
-                "long_name"
-            ] = "photoelectron counts per second"
+            self._normalized.attrs["long_name"] = "photoelectron counts per second"
             self._normalized.attrs["metadata"] = self._attributes.metadata
 
         return self._normalized
Expand Down Expand Up @@ -1510,41 +1501,33 @@ def get_normalization_histogram(

         if use_time_stamps or self._timed_dataframe is None:
             if df_partitions is not None:
-                self._normalization_histogram = (
-                    normalization_histogram_from_timestamps(
-                        self._dataframe.partitions[df_partitions],
-                        axis,
-                        self._binned.coords[axis].values,
-                        self._config["dataframe"]["time_stamp_alias"],
-                    )
+                self._normalization_histogram = normalization_histogram_from_timestamps(
+                    self._dataframe.partitions[df_partitions],
+                    axis,
+                    self._binned.coords[axis].values,
+                    self._config["dataframe"]["time_stamp_alias"],
                 )
             else:
-                self._normalization_histogram = (
-                    normalization_histogram_from_timestamps(
-                        self._dataframe,
-                        axis,
-                        self._binned.coords[axis].values,
-                        self._config["dataframe"]["time_stamp_alias"],
-                    )
+                self._normalization_histogram = normalization_histogram_from_timestamps(
+                    self._dataframe,
+                    axis,
+                    self._binned.coords[axis].values,
+                    self._config["dataframe"]["time_stamp_alias"],
                 )
         else:
             if df_partitions is not None:
-                self._normalization_histogram = (
-                    normalization_histogram_from_timed_dataframe(
-                        self._timed_dataframe.partitions[df_partitions],
-                        axis,
-                        self._binned.coords[axis].values,
-                        self._config["dataframe"]["timed_dataframe_unit_time"],
-                    )
+                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
+                    self._timed_dataframe.partitions[df_partitions],
+                    axis,
+                    self._binned.coords[axis].values,
+                    self._config["dataframe"]["timed_dataframe_unit_time"],
                 )
             else:
-                self._normalization_histogram = (
-                    normalization_histogram_from_timed_dataframe(
-                        self._timed_dataframe,
-                        axis,
-                        self._binned.coords[axis].values,
-                        self._config["dataframe"]["timed_dataframe_unit_time"],
-                    )
+                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
+                    self._timed_dataframe,
+                    axis,
+                    self._binned.coords[axis].values,
+                    self._config["dataframe"]["timed_dataframe_unit_time"],
                 )
 
         return self._normalization_histogram
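Both paths above index the Dask dataframe with `.partitions[df_partitions]`, which lazily selects a subset of partitions without computing anything. A small sketch of that accessor on synthetic data, assuming only standard dask.dataframe behavior (the column name is a placeholder):

```python
import dask.dataframe as dd
import pandas as pd

# Synthetic frame split into 5 partitions of 2 rows each.
pdf = pd.DataFrame({"timeStamp": range(10)})
ddf = dd.from_pandas(pdf, npartitions=5)

# .partitions accepts an int, a slice, or a list of partition indices
# and returns a new, still-lazy dask DataFrame.
subset = ddf.partitions[[0, 2]]
print(subset.compute())  # rows 0-1 and 4-5
```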
tests/calibrator/test_momentum.py (2 changes: 1 addition & 1 deletion)
@@ -11,8 +11,8 @@
 import pytest
 
 from sed.calibrator.momentum import MomentumCorrector
-from sed.core.config import parse_config
 from sed.core import SedProcessor
+from sed.core.config import parse_config
 from sed.loader.loader_interface import get_loader
 
 # pylint: disable=duplicate-code
tests/loader/test_loaders.py (20 changes: 6 additions & 14 deletions)
@@ -95,36 +95,28 @@ def test_has_correct_read_dataframe_func(loader, read_type):
                 extension=supported_file_type,
             )
             if read_type == "one_file":
-                (
-                    loaded_dataframe,
-                    _,
-                    loaded_metadata,
-                ) = loader.read_dataframe(
+                loaded_dataframe, _, loaded_metadata = loader.read_dataframe(
                     files=input_files[0],
                     ftype=supported_file_type,
                     collect_metadata=False,
                 )
                 expected_size = 1
             elif read_type == "files":
-                (
-                    loaded_dataframe,
-                    _,
-                    loaded_metadata,
-                ) = loader.read_dataframe(
+                loaded_dataframe, _, loaded_metadata = loader.read_dataframe(
                     files=list(input_files),
                     ftype=supported_file_type,
                     collect_metadata=False,
                 )
                 expected_size = len(input_files)
             elif read_type == "one_folder":
-                loaded_dataframe, loaded_metadata = loader.read_dataframe(
+                loaded_dataframe, _, loaded_metadata = loader.read_dataframe(
                     folders=input_folder,
                     ftype=supported_file_type,
                     collect_metadata=False,
                 )
                 expected_size = len(input_files)
             elif read_type == "folders":
-                loaded_dataframe, loaded_metadata = loader.read_dataframe(
+                loaded_dataframe, _, loaded_metadata = loader.read_dataframe(
                     folders=[input_folder],
                     ftype=supported_file_type,
                     collect_metadata=False,
@@ -133,7 +125,7 @@
elif read_type == "one_run":
if runs[get_loader_name_from_loader_object(loader)] is None:
pytest.skip("Not implemented")
loaded_dataframe, loaded_metadata = loader.read_dataframe(
loaded_dataframe, _, loaded_metadata = loader.read_dataframe(
runs=runs[get_loader_name_from_loader_object(loader)][0],
ftype=supported_file_type,
collect_metadata=False,
@@ -142,7 +134,7 @@
elif read_type == "runs":
if runs[get_loader_name_from_loader_object(loader)] is None:
pytest.skip("Not implemented")
loaded_dataframe, loaded_metadata = loader.read_dataframe(
loaded_dataframe, _, loaded_metadata = loader.read_dataframe(
runs=runs[get_loader_name_from_loader_object(loader)],
ftype=supported_file_type,
collect_metadata=False,
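The run-based branches above guard loaders without run support via `pytest.skip`. A minimal standalone sketch of that pattern; the run table below is invented for illustration and is not the repository's fixture:

```python
import pytest

# Invented stand-in for the test module's per-loader run table;
# None marks loaders without run-based reading.
runs = {"mpes": None, "flash": ["43878"], "generic": None}

@pytest.mark.parametrize("loader_name", sorted(runs))
def test_loader_runs_defined(loader_name):
    if runs[loader_name] is None:
        pytest.skip("Not implemented")
    # Only loaders with run support reach this assertion.
    assert len(runs[loader_name]) > 0
```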
