diff --git a/src/lcm/grids.py b/src/lcm/grids.py index ca23c00..2e62bae 100644 --- a/src/lcm/grids.py +++ b/src/lcm/grids.py @@ -46,19 +46,10 @@ def __init__(self, category_class: type) -> None: values. """ - if not is_dataclass(category_class): - raise GridInitializationError( - "category_class must be a dataclass with scalar int or float fields, " - f"but is {category_class}." - ) + _validate_discrete_grid(category_class) names_and_values = _get_field_names_and_values(category_class) - errors = _validate_discrete_grid(names_and_values) - if errors: - msg = format_messages(errors) - raise GridInitializationError(msg) - self.__categories = list(names_and_values.keys()) self.__codes = list(names_and_values.values()) @@ -159,17 +150,21 @@ def get_coordinate(self, value: Scalar) -> Scalar: # ====================================================================================== -def _validate_discrete_grid(names_and_values: dict[str, Any]) -> list[str]: +def _validate_discrete_grid(category_class: type) -> None: """Validate the field names and values of the category_class passed to DiscreteGrid. Args: - names_and_values: A dictionary with the field names as keys and the field - values as values. - - Returns: - list[str]: A list of error messages. + category_class: The class with mappings of names to codes. """ + if not is_dataclass(category_class): + raise GridInitializationError( + "category_class must be a dataclass with scalar int or float fields, " + f"but is {category_class}." + ) + + names_and_values = _get_field_names_and_values(category_class) + error_messages = [] if not len(names_and_values) > 0: @@ -198,7 +193,9 @@ def _validate_discrete_grid(names_and_values: dict[str, Any]) -> list[str]: f"{set(duplicated_values)}" ) - return error_messages + if error_messages: + msg = format_messages(error_messages) + raise GridInitializationError(msg) def _get_field_names_and_values(dc: type) -> dict[str, Any]: