Skip to content

Commit

Permalink
Use a match statement instead of an elif-cascade
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Oct 12, 2023
1 parent 37d84cf commit df84563
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
17 changes: 9 additions & 8 deletions ranzen/torch/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ class ReductionType(Enum):
def reduce(losses: Tensor, reduction_type: ReductionType | str) -> Tensor:
if isinstance(reduction_type, str):
reduction_type = str_to_enum(str_=reduction_type, enum=ReductionType)
if reduction_type is ReductionType.mean:
return losses.mean()
elif reduction_type is ReductionType.batch_mean:
return losses.sum() / losses.size(0)
elif reduction_type is ReductionType.sum:
return losses.sum()
elif reduction_type is ReductionType.none:
return losses
match reduction_type:
case ReductionType.mean:
return losses.mean()
case ReductionType.batch_mean:
return losses.sum() / losses.size(0)
case ReductionType.sum:
return losses.sum()
case ReductionType.none:
return losses
raise TypeError(
f"Received invalid type '{type(reduction_type)}' for argument 'reduction_type'."
)
Expand Down
2 changes: 1 addition & 1 deletion ranzen/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def modify_config(
logger.info(f"Changed config for {i} runs.")

@staticmethod
def _runs_to_df(runs: Sequence[wandb.sdk.wandb_run.Run]) -> pd.DataFrame: # type: ignore
def _runs_to_df(runs: Sequence[wandb.sdk.wandb_run.Run]) -> pd.DataFrame: # pyright: ignore
summary_list = []
config_list = []
name_list = []
Expand Down

0 comments on commit df84563

Please sign in to comment.