Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unified typing of df_partitions #298

Merged
merged 5 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ name: linting
# Triggers the workflow on push for all branches
on:
push:
branches: [ main ]
branches:
- '*'
paths-ignore:
pyproject.toml

Expand Down
42 changes: 24 additions & 18 deletions sed/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_filter
from sed.core.dfops import add_time_stamped_data
from sed.core.dfops import apply_filter
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
Expand Down Expand Up @@ -453,7 +453,7 @@ def filter_column(
# 1. Bin raw detector data for distortion correction
def bin_and_load_momentum_calibration(
self,
df_partitions: int = 100,
df_partitions: Union[int, Sequence[int]] = 100,
axes: List[str] = None,
bins: List[int] = None,
ranges: Sequence[Tuple[float, float]] = None,
Expand All @@ -467,8 +467,8 @@ def bin_and_load_momentum_calibration(
interactive view, and load it into the momentum corrector class.

Args:
df_partitions (int, optional): Number of dataframe partitions to use for
the initial binning. Defaults to 100.
df_partitions (Union[int, Sequence[int]], optional): Number of dataframe partitions
to use for the initial binning. Defaults to 100.
axes (List[str], optional): Axes to bin.
Defaults to config["momentum"]["axes"].
bins (List[int], optional): Bin numbers to use for binning.
Expand Down Expand Up @@ -1792,7 +1792,7 @@ def add_time_stamped_data(

def pre_binning(
self,
df_partitions: int = 100,
df_partitions: Union[int, Sequence[int]] = 100,
axes: List[str] = None,
bins: List[int] = None,
ranges: Sequence[Tuple[float, float]] = None,
Expand All @@ -1801,8 +1801,8 @@ def pre_binning(
"""Function to do an initial binning of the dataframe loaded to the class.

Args:
df_partitions (int, optional): Number of dataframe partitions to use for
the initial binning. Defaults to 100.
df_partitions (Union[int, Sequence[int]], optional): Number of dataframe partitions to
use for the initial binning. Defaults to 100.
axes (List[str], optional): Axes to bin.
Defaults to config["momentum"]["axes"].
bins (List[int], optional): Bin numbers to use for binning.
Expand Down Expand Up @@ -1895,7 +1895,7 @@ def compute(
- **threadpool_api**: The API to use for multiprocessing. "blas",
"openmp" or None. See ``threadpool_limit`` for details. Defaults to
config["binning"]["threadpool_API"].
- **df_partitions**: A range or list of dataframe partitions, or the
- **df_partitions**: A sequence of dataframe partitions, or the
  number of dataframe partitions to use. Defaults to all partitions.

Additional kwds are passed to ``bin_dataframe``.
Expand All @@ -1921,11 +1921,14 @@ def compute(
"threadpool_API",
self._config["binning"]["threadpool_API"],
)
df_partitions = kwds.pop("df_partitions", None)
df_partitions: Union[int, Sequence[int]] = kwds.pop("df_partitions", None)
if isinstance(df_partitions, int):
df_partitions = slice(
0,
min(df_partitions, self._dataframe.npartitions),
df_partitions = cast(
Sequence[int],
np.arange(
0,
min(df_partitions, self._dataframe.npartitions),
),
)
if df_partitions is not None:
dataframe = self._dataframe.partitions[df_partitions]
Expand Down Expand Up @@ -2009,8 +2012,8 @@ def get_normalization_histogram(
dataframe, rather than the timed dataframe. Defaults to False.
**kwds: Keyword arguments:

-df_partitions (int, optional): Number of dataframe partitions to use.
Defaults to all.
- **df_partitions**: A sequence of dataframe partitions, or the
  number of dataframe partitions to use. Defaults to all partitions.

Raises:
ValueError: Raised if no data are binned.
Expand All @@ -2028,11 +2031,14 @@ def get_normalization_histogram(
if axis not in self._binned.coords:
raise ValueError(f"Axis '{axis}' not found in binned data!")

df_partitions: Union[int, slice] = kwds.pop("df_partitions", None)
df_partitions: Union[int, Sequence[int]] = kwds.pop("df_partitions", None)
if isinstance(df_partitions, int):
df_partitions = slice(
0,
min(df_partitions, self._dataframe.npartitions),
df_partitions = cast(
Sequence[int],
np.arange(
0,
min(df_partitions, self._dataframe.npartitions),
),
)

if use_time_stamps or self._timed_dataframe is None:
Expand Down