From 28997b2d0d819ea222458a5fb7ccd7bddefbb93b Mon Sep 17 00:00:00 2001
From: wbenbihi
Date: Sun, 21 Aug 2022 21:52:33 +0800
Subject: [PATCH] [ADD][FEAT] Handlers

---
 hourglass_tensorflow/handlers/__init__.py |   0
 hourglass_tensorflow/handlers/data.py     | 185 ++++++++++++++
 hourglass_tensorflow/handlers/dataset.py  | 288 ++++++++++++++++++++++
 hourglass_tensorflow/handlers/engines.py  | 129 ++++++++++
 hourglass_tensorflow/handlers/meta.py     |  54 ++++
 hourglass_tensorflow/handlers/model.py    |   0
 hourglass_tensorflow/handlers/train.py    |   0
 7 files changed, 656 insertions(+)
 create mode 100644 hourglass_tensorflow/handlers/__init__.py
 create mode 100644 hourglass_tensorflow/handlers/data.py
 create mode 100644 hourglass_tensorflow/handlers/dataset.py
 create mode 100644 hourglass_tensorflow/handlers/engines.py
 create mode 100644 hourglass_tensorflow/handlers/meta.py
 create mode 100644 hourglass_tensorflow/handlers/model.py
 create mode 100644 hourglass_tensorflow/handlers/train.py

diff --git a/hourglass_tensorflow/handlers/__init__.py b/hourglass_tensorflow/handlers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/hourglass_tensorflow/handlers/data.py b/hourglass_tensorflow/handlers/data.py
new file mode 100644
index 0000000..371d842
--- /dev/null
+++ b/hourglass_tensorflow/handlers/data.py
@@ -0,0 +1,185 @@
import os
import itertools
from abc import abstractmethod
from glob import glob
from typing import List

import pandas as pd

from hourglass_tensorflow.handlers.meta import _HTFHandler
from hourglass_tensorflow.utils._errors import BadConfigurationError
from hourglass_tensorflow.types.config.data import HTFDataInput
from hourglass_tensorflow.types.config.data import HTFDataConfig
from hourglass_tensorflow.types.config.data import HTFDataOutput

# region Abstract Class


class _HTFDataHandler(_HTFHandler):
    def __init__(self, config: HTFDataConfig, *args, **kwargs) -> None:
        super().__init__(config=config, *args, **kwargs)

    @property
    def config(self) -> HTFDataConfig:
        return self._config

    @property
    def input_cfg(self) -> HTFDataInput:
        return self.config.input

    @property
    def output_cfg(self) -> HTFDataOutput:
        return self.config.output

    @abstractmethod
    def prepare_input(self, *args, **kwargs) -> None:
        raise NotImplementedError

    @abstractmethod
    def prepare_output(self, *args, **kwargs) -> None:
        raise NotImplementedError

    def run(self, *args, **kwargs) -> None:
        self.prepare_input(*args, **kwargs)
        self.prepare_output(*args, **kwargs)


# endregion

# region Handler


class HTFDataHandler(_HTFDataHandler):
    def _list_input_images(self) -> None:
        """List the images available in the input source folder.

        Raises:
            BadConfigurationError: the configured source folder does not exist
        """
        if not os.path.exists(self.input_cfg.source):
            raise BadConfigurationError(
                f"Unable to find source folder {self.input_cfg.source}"
            )
        self.info(
            f"Listing {self.input_cfg.extensions} images in {self.input_cfg.source}"
        )
        self._metadata.available_images = set(
            itertools.chain(
                *[
                    glob(os.path.join(self.input_cfg.source, f"*.{ext}"))
                    for ext in self.input_cfg.extensions
                ]
            )
        )

    def _validate_joints(self, _error: bool = True) -> bool:
        conditions = [len(self.output_cfg.joints.names) == self.output_cfg.joints.num]
        if not all(conditions):
            if _error:
                raise BadConfigurationError("Joints properties are not valid")
            return False
        return True
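    # As an illustration (these values are assumptions, not enforced by this
    # module), a naming convention of "joint_{JOINT_ID}_{SUFFIX}" with the
    # suffixes (x, y, visible) and joints.num == 2 expands to the headers:
    #   joint_0_x, joint_0_y, joint_0_visible, joint_1_x, joint_1_y, joint_1_visible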
    def _valid_labels_header(self, df: pd.DataFrame, _error: bool = False) -> bool:
        # Number of joints expected by the configuration
        n_joint = self.output_cfg.joints.num
        # Build the expected column names from the naming convention
        naming_convention = self.output_cfg.joints.naming_convention
        suffixes = self.output_cfg.joints.suffixes
        headers = self.output_cfg.prefix_columns + [
            naming_convention.format(JOINT_ID=jid, SUFFIX=suffix)
            for jid in range(n_joint)
            for suffix in suffixes.__dict__.values()
        ]
        if set(headers).difference(set(df.columns)):
            if _error:
                raise BadConfigurationError(
                    f"Column names do not match the configuration\n\tEXPECTED:\n\t{headers}\n\tRECEIVED:\n\t{list(df.columns)}\n\tMISSING COLUMNS:\n\t{set(headers).difference(set(df.columns))}"
                )
            return False
        # If everything is good we store the expected headers in _metadata
        self._metadata.label_headers = headers
        return True

    def _load_labels(self) -> pd.DataFrame:
        self.info(f"Reading labels from {self.output_cfg.source}")
        # Check that the file extension is one of [.json, .csv]
        if self.output_cfg.source.endswith(".json"):
            self._metadata.label_type = "json"
            labels = pd.read_json(self.output_cfg.source, orient="records")
        elif self.output_cfg.source.endswith(".csv"):
            self._metadata.label_type = "csv"
            labels = pd.read_csv(self.output_cfg.source)
        else:
            raise BadConfigurationError(
                f"{self.output_cfg.source} should be of type .json or .csv"
            )
        if not isinstance(labels, pd.DataFrame):
            raise BadConfigurationError(
                f"{self.output_cfg.source} not parsable as pandas.DataFrame"
            )
        return labels

    def _prefix_images(self, df: pd.DataFrame) -> pd.DataFrame:
        # Prefix the image column with the image folder
        # in case the source_prefixed attribute is set to False
        folder_prefix = self.input_cfg.source
        source_column = self.output_cfg.column_source
        df = df.assign(
            **{
                source_column: df[source_column].apply(
                    lambda x: os.path.join(folder_prefix, x)
                )
            }
        )
        return df
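    # Sketch of a compatible CSV label file (purely illustrative; the actual
    # column names depend on prefix_columns and the naming convention above):
    #
    #   image,set,joint_0_x,joint_0_y,joint_0_visible,...
    #   img_0001.jpg,TRAIN,620,394,1,...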
    def _read_labels(self, _error: bool = False) -> bool:
        # Check that data.output.source exists
        if not os.path.exists(self.output_cfg.source):
            raise BadConfigurationError(f"Unable to find {self.output_cfg.source}")
        # Read Data
        labels = self._load_labels()
        # Validate expected labels columns
        if not self._valid_labels_header(labels, _error=_error):
            self.error("Label columns do not match the expected headers")
            return False
        self._labels_df: pd.DataFrame = labels[self.meta.label_headers]
        self._metadata.label_mapper = {
            label: i for i, label in enumerate(self.meta.label_headers)
        }
        if not self.output_cfg.source_prefixed:
            self._labels_df = self._prefix_images(self._labels_df)
        return True

    def _get_joint_columns(self) -> List[str]:
        JOINT_CFG = self.config.output.joints
        num_joints = JOINT_CFG.num
        dynamic_fields = JOINT_CFG.dynamic_fields
        data_format = JOINT_CFG.format
        index_field = JOINT_CFG.format.id_field
        naming = JOINT_CFG.naming_convention
        # Cartesian product of the dynamic field values (e.g. the suffixes)
        groups = itertools.product(
            *[list(getattr(data_format, g).values()) for g in dynamic_fields]
        )
        named_groups = [
            {dynamic_fields[i]: el for i, el in enumerate(group)} for group in groups
        ]
        return [
            naming.format(**{**group, **{index_field: joint_idx}})
            for joint_idx in range(num_joints)
            for group in named_groups
        ]

    def prepare_input(self) -> None:
        # List files in Input Source Folder
        self._list_input_images()

    def prepare_output(self, _error: bool = True) -> None:
        # Compute the expected joint columns, then validate and read the label file
        self._metadata.joint_columns = self._get_joint_columns()
        self._validate_joints(_error=_error)
        self._read_labels(_error=_error)


# endregion
diff --git a/hourglass_tensorflow/handlers/dataset.py b/hourglass_tensorflow/handlers/dataset.py
new file mode 100644
index 0000000..d80e647
--- /dev/null
+++ b/hourglass_tensorflow/handlers/dataset.py
@@ -0,0 +1,288 @@
from abc import abstractmethod
from typing import Any
from typing import Set
from typing import Dict
from typing import Type
from typing import Tuple
from typing import Union
from typing import Iterable
from typing import Optional

import numpy as np
import pandas as pd
import tensorflow as tf

from hourglass_tensorflow.utils.sets import split_train_test
from hourglass_tensorflow.handlers.meta import _HTFHandler
from hourglass_tensorflow.handlers.engines import ENGINES
from hourglass_tensorflow.handlers.engines import HTFEngine
from hourglass_tensorflow.types.config.dataset import HTFDatasetBBox
from hourglass_tensorflow.types.config.dataset import HTFDatasetSets
from hourglass_tensorflow.types.config.dataset import HTFDatasetConfig
from hourglass_tensorflow.types.config.dataset import HTFDatasetHeatmap
from hourglass_tensorflow.datasets.transformation import tf_train_map_stacks
from hourglass_tensorflow.datasets.transformation import tf_train_map_heatmaps
from hourglass_tensorflow.datasets.transformation import tf_train_map_squarify
from hourglass_tensorflow.datasets.transformation import tf_train_map_normalize
from hourglass_tensorflow.datasets.transformation import tf_train_map_build_slice
from hourglass_tensorflow.datasets.transformation import tf_train_map_resize_data

# region Abstract Class

HTFDataTypes = Union[np.ndarray, pd.DataFrame]
ImageSetsType = Tuple[Optional[Set[str]], Optional[Set[str]], Optional[Set[str]]]


class _HTFDatasetHandler(_HTFHandler):

    _ENGINES: Dict[Any, Type[HTFEngine]] = ENGINES
    ENGINES: Dict[Any, Type[HTFEngine]] = {}

    def __init__(
        self,
        data: HTFDataTypes,
        config: HTFDatasetConfig,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(config=config, *args, **kwargs)
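        # Keep a handle on the raw data and resolve the engine matching its type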
        self.data = data
        self.engine: HTFEngine = self.select_engine(data)

    @property
    def _engines(self) -> Dict[Type, Type[HTFEngine]]:
        return {**self._ENGINES, **self.ENGINES}

    @property
    def config(self) -> HTFDatasetConfig:
        return self._config

    @property
    def sets(self) -> HTFDatasetSets:
        return self.config.sets

    @property
    def bbox(self) -> HTFDatasetBBox:
        return self.config.bbox

    @property
    def heatmap(self) -> HTFDatasetHeatmap:
        return self.config.heatmap

    def select_engine(self, data: Any) -> HTFEngine:
        try:
            return self._engines[type(data)](metadata=self._metadata)
        except KeyError:
            raise KeyError(f"No engine available for type {type(data)}")

    @abstractmethod
    def prepare_dataset(self, *args, **kwargs) -> None:
        raise NotImplementedError

    @abstractmethod
    def generate_datasets(self, *args, **kwargs) -> None:
        raise NotImplementedError

    def run(self, *args, **kwargs) -> None:
        self.prepare_dataset(*args, **kwargs)
        self.generate_datasets(*args, **kwargs)


# endregion

# region Handler


class HTFDatasetHandler(_HTFDatasetHandler):
    @property
    def has_train(self) -> bool:
        return self.sets.train

    @property
    def has_test(self) -> bool:
        return self.sets.test

    @property
    def has_validation(self) -> bool:
        return self.sets.validation

    @property
    def ratio_train(self) -> float:
        return self.sets.ratio_train if self.has_train else 0.0

    @property
    def ratio_test(self) -> float:
        return self.sets.ratio_test if self.has_test else 0.0

    @property
    def ratio_validation(self) -> float:
        return self.sets.ratio_validation if self.has_validation else 0.0

    def init_handler(self, *args, **kwargs) -> None:
        self.split_done = False
        # Init attributes
        self._test_set: Optional[HTFDataTypes] = None
        self._train_set: Optional[HTFDataTypes] = None
        self._validation_set: Optional[HTFDataTypes] = None
        self._test_dataset: Optional[tf.data.Dataset] = None
        self._train_dataset: Optional[tf.data.Dataset] = None
        self._validation_dataset: Optional[tf.data.Dataset] = None
        self.kwargs = kwargs

    # region Prepare Dataset Hidden Methods
    def _generate_image_sets(self, images: Set[str]) -> ImageSetsType:
        # Generate Sets (see the ratio sketch after this method)
        train = set()
        test = set()
        validation = set()
        if self.has_train:
            # Has training samples
            if self.has_test and self.has_validation:
                # + Validation and Test
                train, test = split_train_test(
                    images, self.ratio_train + self.ratio_validation
                )
                train, validation = split_train_test(
                    train, self.ratio_train / (self.ratio_train + self.ratio_validation)
                )
            elif self.has_validation:
                train, validation = split_train_test(images, self.ratio_train)
            elif self.has_test:
                train, test = split_train_test(images, self.ratio_train)
            else:
                train = images
        else:
            if self.has_test and self.has_validation:
                test, validation = split_train_test(images, self.ratio_test)
            elif self.has_test:
                test = images
            else:
                validation = images
        return train, test, validation

    def _split_by_column(self) -> Tuple[HTFDataTypes, HTFDataTypes, HTFDataTypes]:
        train = self.engine.filter_data(
            data=self.data,
            column=self.sets.column_split,
            set_name=self.sets.value_train,
        )
        test = self.engine.filter_data(
            data=self.data,
            column=self.sets.column_split,
            set_name=self.sets.value_test,
        )
        validation = self.engine.filter_data(
            data=self.data,
            column=self.sets.column_split,
            set_name=self.sets.value_validation,
        )
        return train, test, validation
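    # Ratio semantics used above (a sketch, assuming split_train_test(items, r)
    # returns the pair (share of size r, remainder)): with ratio_train=0.7 and
    # ratio_validation=0.15, the first split keeps 85% of the images, and the
    # second keeps 0.7/0.85 of those for training, leaving 15% overall for
    # validation.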
    def _split_by_ratio(self) -> Tuple[HTFDataTypes, HTFDataTypes, HTFDataTypes]:
        # Get set of unique images
        images = self.engine.get_images(data=self.data, column=self.config.column_image)
        img_train, img_test, img_validation = self._generate_image_sets(images)
        # Save on metadata
        self._metadata.test_images = img_test
        self._metadata.train_images = img_train
        self._metadata.validation_images = img_validation
        # Select Subsets within the main data collection
        test = self.engine.select_subset_from_images(
            data=self.data, image_set=img_test, column=self.config.column_image
        )
        train = self.engine.select_subset_from_images(
            data=self.data, image_set=img_train, column=self.config.column_image
        )
        validation = self.engine.select_subset_from_images(
            data=self.data, image_set=img_validation, column=self.config.column_image
        )
        return train, test, validation

    def _split_sets(self) -> None:
        if self.sets.split_by_column:
            # Use a predefined column as discriminant
            train, test, validation = self._split_by_column()
        else:
            # Split by train/test/validation ratios
            train, test, validation = self._split_by_ratio()
        self._train_set = train
        self._test_set = test
        self._validation_set = validation
        self.split_done = True

    def prepare_dataset(self, *args, **kwargs) -> None:
        self._split_sets()

    # endregion

    # region Generate Datasets Hidden Methods
    def _extract_columns_from_data(
        self, data: HTFDataTypes
    ) -> Tuple[Iterable, Iterable]:
        # Extract filenames and joint coordinates
        filenames = self.engine.to_list(
            self.engine.get_columns(data=data, columns=self.config.column_image)
        )
        coordinates = self.engine.to_list(
            self.engine.get_columns(data=data, columns=self.meta.joint_columns)
        )
        return filenames, coordinates

    def _create_dataset(self, dataset: HTFDataTypes) -> tf.data.Dataset:
        return (
            tf.data.Dataset.from_tensor_slices(
                self._extract_columns_from_data(data=dataset)
            )
            .map(
                # Load Images
                tf_train_map_build_slice
            )
            .map(
                # Compute BBOX cropping
                lambda img, coord, vis: tf_train_map_squarify(
                    img,
                    coord,
                    vis,
                    bbox_enabled=self.config.bbox.activate,
                    bbox_factor=self.config.bbox.factor,
                )
            )
            .map(
                # Resize Image
                lambda img, coord, vis: tf_train_map_resize_data(
                    img, coord, vis, input_size=int(self.config.image_size)
                )
            )
            .map(
                # Get Heatmaps
                lambda img, coord, vis: tf_train_map_heatmaps(
                    img,
                    coord,
                    vis,
                    output_size=int(self.config.heatmap.size),
                    stddev=self.config.heatmap.stddev,
                )
            )
            .map(
                # Normalize Data
                lambda img, hms: tf_train_map_normalize(
                    img,
                    hms,
                    normalization=self.config.normalization,
                )
            )
            .map(
                # Stacks
                lambda img, hms: tf_train_map_stacks(
                    img,
                    hms,
                    stacks=self.config.heatmap.stacks,
                )
            )
        )

    def generate_datasets(self, *args, **kwargs) -> None:
        self._train_dataset = self._create_dataset(self._train_set)
        self._test_dataset = self._create_dataset(self._test_set)
        self._validation_dataset = self._create_dataset(self._validation_set)

    # endregion


# endregion
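# Usage sketch (illustrative; `labels_df` would come from HTFDataHandler and
# `dataset_cfg` from the parsed configuration):
#
#     handler = HTFDatasetHandler(data=labels_df, config=dataset_cfg)
#     handler()  # splits the sets, then builds the tf.data.Dataset objects
#     train_ds = handler._train_dataset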
diff --git a/hourglass_tensorflow/handlers/engines.py b/hourglass_tensorflow/handlers/engines.py
new file mode 100644
index 0000000..2f99a56
--- /dev/null
+++ b/hourglass_tensorflow/handlers/engines.py
@@ -0,0 +1,129 @@
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import Set
from typing import Dict
from typing import List
from typing import Type

import numpy as np
import pandas as pd
import tensorflow as tf

from hourglass_tensorflow.utils.object_logger import ObjectLogger
from hourglass_tensorflow.types.config.metadata import HTFMetadata


class HTFEngine(ABC, ObjectLogger):
    FOR_TYPE = None

    def __init__(
        self, metadata: HTFMetadata, verbose: bool = True, *args, **kwargs
    ) -> None:
        super().__init__(verbose=verbose, *args, **kwargs)
        self.metadata = metadata

    @abstractmethod
    def get_images(self, data: Any, column: str) -> Set[str]:
        raise NotImplementedError

    @abstractmethod
    def filter_data(self, data: Any, column: str, set_name: str) -> Any:
        raise NotImplementedError

    @abstractmethod
    def select_subset_from_images(
        self, data: Any, image_set: Set[str], column: str
    ) -> Any:
        raise NotImplementedError

    @abstractmethod
    def get_columns(self, data: Any, columns: List[str]) -> Any:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def to_list(data: Any) -> List:
        raise NotImplementedError


class HTFNumpyEngine(HTFEngine):
    FOR_TYPE = np.ndarray

    def get_images(self, data: np.ndarray, column: str) -> Set[str]:
        images: Set[str] = set(data[:, self.metadata.label_mapper[column]].tolist())
        return images

    def filter_data(self, data: np.ndarray, column: str, set_name: str) -> np.ndarray:
        mask = data[:, self.metadata.label_mapper[column]] == set_name
        filtered_data = data[mask, :]
        return filtered_data

    def select_subset_from_images(
        self, data: np.ndarray, image_set: Set[str], column: str
    ) -> np.ndarray:
        # np.isin expects an array-like, not a set
        indices = np.isin(data[:, self.metadata.label_mapper[column]], list(image_set))
        return data[indices]

    def get_columns(self, data: np.ndarray, columns: List[str]) -> np.ndarray:
        idx_columns = [self.metadata.label_mapper[col] for col in columns]
        return data[:, idx_columns]

    @staticmethod
    def to_list(data: np.ndarray) -> List:
        return data.tolist()


class HTFPandasEngine(HTFEngine):
    FOR_TYPE = pd.DataFrame

    def get_images(self, data: pd.DataFrame, column: str) -> Set[str]:
        images: Set[str] = set(data[column].tolist())
        return images

    def filter_data(
        self, data: pd.DataFrame, column: str, set_name: str
    ) -> pd.DataFrame:
        return data.query(f"{column} == '{set_name}'")

    def select_subset_from_images(
        self, data: pd.DataFrame, image_set: Set[str], column: str
    ) -> pd.DataFrame:
        return data[data[column].isin(image_set)]

    def get_columns(self, data: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
        return data[columns]

    @staticmethod
    def to_list(data: pd.DataFrame) -> List:
        return data.values.tolist()


class HTFTensorflowEngine(HTFEngine):
    # Placeholder: tensor-based data collections are not supported yet
    FOR_TYPE = tf.Tensor

    def get_images(self, data: tf.Tensor, column: str) -> Set[str]:
        raise NotImplementedError

    def filter_data(self, data: tf.Tensor, column: str, set_name: str) -> tf.Tensor:
        raise NotImplementedError

    def select_subset_from_images(
        self, data: tf.Tensor, image_set: Set[str], column: str
    ) -> tf.Tensor:
        raise NotImplementedError

    def get_columns(self, data: tf.Tensor, columns: List[str]) -> tf.Tensor:
        raise NotImplementedError

    @staticmethod
    def to_list(data: tf.Tensor) -> List:
        raise NotImplementedError


ENGINES: Dict[Type, Type[HTFEngine]] = {
    np.ndarray: HTFNumpyEngine,
    pd.DataFrame: HTFPandasEngine,
    tf.Tensor: HTFTensorflowEngine,
}
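# Dispatch sketch: handlers resolve an engine from the data container type,
# e.g. ENGINES[pd.DataFrame] -> HTFPandasEngine. Registering a custom engine
# (illustrative) only requires subclassing HTFEngine and adding it to this
# mapping, or to the ENGINES class attribute of a _HTFDatasetHandler subclass.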
diff --git a/hourglass_tensorflow/handlers/meta.py b/hourglass_tensorflow/handlers/meta.py
new file mode 100644
index 0000000..2dc3c4e
--- /dev/null
+++ b/hourglass_tensorflow/handlers/meta.py
@@ -0,0 +1,54 @@
from abc import ABC
from abc import abstractmethod
from typing import Optional

from hourglass_tensorflow.types.config.fields import HTFConfigField
from hourglass_tensorflow.utils.object_logger import ObjectLogger
from hourglass_tensorflow.types.config.metadata import HTFMetadata


class _HTFHandler(ABC, ObjectLogger):
    def __init__(
        self,
        config: HTFConfigField,
        metadata: Optional[HTFMetadata] = None,
        verbose: bool = True,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(verbose=verbose)
        self._config = config
        self._metadata = metadata if metadata is not None else HTFMetadata()
        self._executed = False
        self.init_handler(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> None:
        if not self._executed:
            self.run(*args, **kwargs)
            self._executed = True
        else:
            self.warning(
                f"This {self.__class__.__name__} has already been executed. Use self.reset"
            )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(executed={self._executed})>"

    @property
    def config(self) -> HTFConfigField:
        return self._config

    @property
    def meta(self) -> HTFMetadata:
        return self._metadata

    def init_handler(self, *args, **kwargs) -> None:
        pass

    def reset(self, *args, **kwargs) -> "_HTFHandler":
        return self.__class__(
            config=self.config, verbose=self._verbose, *args, **kwargs
        )

    @abstractmethod
    def run(self):
        raise NotImplementedError
diff --git a/hourglass_tensorflow/handlers/model.py b/hourglass_tensorflow/handlers/model.py
new file mode 100644
index 0000000..e69de29
diff --git a/hourglass_tensorflow/handlers/train.py b/hourglass_tensorflow/handlers/train.py
new file mode 100644
index 0000000..e69de29