Skip to content

Commit

Permalink
fix & extend StrEnum.from_str (#99)
Browse files Browse the repository at this point in the history
* fix & extend StrEnum.from_str
* chlog & warning
* mypy

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Feb 7, 2023
1 parent 9eab563 commit b76c85a
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Allow frozen dataclasses in `apply_to_collection` ([#98](https://github.com/Lightning-AI/utilities/pull/98))


- Extended `StrEnum.from_str` with optional raising ValueError ([#99](https://github.com/Lightning-AI/utilities/pull/99))


### Changed

- CI/docs: allow passing env. variables ([#96](https://github.com/Lightning-AI/utilities/pull/96))


### Fixed

-
- Fixed `StrEnum.from_str` with source as key ([#99](https://github.com/Lightning-AI/utilities/pull/99))


## [0.6.0] - 2023-01-23
Expand Down
68 changes: 61 additions & 7 deletions src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#
import warnings
from enum import Enum
from typing import Optional

Expand All @@ -16,19 +17,72 @@ class StrEnum(str, Enum):
... t2 = "T-2"
>>> MySE("T-1") == MySE.t1
True
>>> MySE.from_str("t-2") == MySE.t2
>>> MySE.from_str("t-2", source="value") == MySE.t2
True
"""

@classmethod
def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]:
for st, val in cls.__members__.items():
if source in ("key", "any") and st.lower() == value.lower():
return cls[st]
if source in ("value", "any") and val.lower() == value.lower():
return cls[st]
def from_str(
cls, value: str, source: Literal["key", "value", "any"] = "key", strict: bool = False
) -> Optional["StrEnum"]:
"""Create StrEnum from a sting matching the key or value.
Args:
value: matching string
source: compare with:
- ``"key"``: validates only with Enum keys, typical alphanumeric with "_"
- ``"value"``: validates only with Enum values, could be any string
- ``"key"``: validates with any key or value, but key has priority
strict: allow not matching string and returns None; if false raises exceptions
Raises:
ValueError:
if requested string does not match any option based on selected source and use ``"strict=True"``
UserWarning:
if requested string does not match any option based on selected source and use ``"strict=False"``
Example:
>>> class MySE(StrEnum):
... t1 = "T-1"
... t2 = "T-2"
>>> MySE.from_str("t-1", source="key")
>>> MySE.from_str("t-2", source="value")
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any", strict=True)
Traceback (most recent call last):
...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
"""
allowed = cls._allowed_matches(source)
if strict and not any(enum_.lower() == value.lower() for enum_ in allowed):
raise ValueError(f"Invalid match: expected one of {allowed}, but got {value}.")

if source in ("key", "any"):
for enum_key in cls.__members__.keys():
if enum_key.lower() == value.lower():
return cls[enum_key]
if source in ("value", "any"):
for enum_key, enum_val in cls.__members__.items():
if enum_val == value:
return cls[enum_key]

warnings.warn(UserWarning(f"Invalid string: expected one of {allowed}, but got {value}."))
return None

@classmethod
def _allowed_matches(cls, source: str) -> list:
keys, vals = [], []
for enum_key, enum_val in cls.__members__.items():
keys.append(enum_key)
vals.append(enum_val.value)
if source == "key":
return keys
if source == "value":
return vals
return keys + vals

def __eq__(self, other: object) -> bool:
if isinstance(other, Enum):
other = other.value
Expand Down

0 comments on commit b76c85a

Please sign in to comment.