From df8456330447a61543d1fb62024e57e2640b66b4 Mon Sep 17 00:00:00 2001 From: Thomas M Kehrenberg Date: Thu, 12 Oct 2023 14:52:02 +0200 Subject: [PATCH] Use a match statement instead of an elif-cascade --- ranzen/torch/loss.py | 17 +++++++++-------- ranzen/wandb.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ranzen/torch/loss.py b/ranzen/torch/loss.py index 266a11f4..7926f292 100644 --- a/ranzen/torch/loss.py +++ b/ranzen/torch/loss.py @@ -31,14 +31,15 @@ class ReductionType(Enum): def reduce(losses: Tensor, reduction_type: ReductionType | str) -> Tensor: if isinstance(reduction_type, str): reduction_type = str_to_enum(str_=reduction_type, enum=ReductionType) - if reduction_type is ReductionType.mean: - return losses.mean() - elif reduction_type is ReductionType.batch_mean: - return losses.sum() / losses.size(0) - elif reduction_type is ReductionType.sum: - return losses.sum() - elif reduction_type is ReductionType.none: - return losses + match reduction_type: + case ReductionType.mean: + return losses.mean() + case ReductionType.batch_mean: + return losses.sum() / losses.size(0) + case ReductionType.sum: + return losses.sum() + case ReductionType.none: + return losses raise TypeError( f"Received invalid type '{type(reduction_type)}' for argument 'reduction_type'." ) diff --git a/ranzen/wandb.py b/ranzen/wandb.py index d525f315..09477a94 100644 --- a/ranzen/wandb.py +++ b/ranzen/wandb.py @@ -61,7 +61,7 @@ def modify_config( logger.info(f"Changed config for {i} runs.") @staticmethod - def _runs_to_df(runs: Sequence[wandb.sdk.wandb_run.Run]) -> pd.DataFrame: # type: ignore + def _runs_to_df(runs: Sequence[wandb.sdk.wandb_run.Run]) -> pd.DataFrame: # pyright: ignore summary_list = [] config_list = [] name_list = []