
Commit

[FIX] Fix DatasetHandler
wbenbihi committed Aug 19, 2022
1 parent 52ec98d commit 0f15c43
Showing 1 changed file with 127 additions and 40 deletions.
hourglass_tensorflow/datasets/_meta.py (167 changes: 127 additions & 40 deletions)
@@ -1,4 +1,5 @@
 from abc import ABC
+from abc import abstractmethod
 from typing import TYPE_CHECKING
 from typing import Set
 from typing import Dict
@@ -7,7 +8,6 @@
 from typing import Union
 from typing import Iterable
 from typing import Optional
-from posixpath import split

 import numpy as np
 import pandas as pd
@@ -16,6 +16,7 @@

 from hourglass_tensorflow._errors import BadConfigurationError
 from hourglass_tensorflow.utils.sets import split_train_test
+from hourglass_tensorflow.utils.object_logger import ObjectLogger

 if TYPE_CHECKING:
     from hourglass_tensorflow.utils.config import HTFConfiguration
@@ -45,7 +46,7 @@ class Config:
 def global_config_required(method):
     def wrapper(self: "HTFBaseDatasetHandler", *args, **kwargs):
         if self.global_config is not None:
-            method(self, *args, **kwargs)
+            return method(self, *args, **kwargs)
         else:
             raise AttributeError(
                 f"{self} has no global configuration of type <HTFConfiguration>"
@@ -54,12 +55,13 @@ def wrapper(self: "HTFBaseDatasetHandler", *args, **kwargs):
     return wrapper
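
The one-line change above is the core of this fix: `wrapper` previously called `method` without returning its result, so every method decorated with `global_config_required` implicitly returned `None`. A minimal, self-contained sketch of the failure mode (the `double_*` functions are illustrative, not from the repo):

def broken(method):
    def wrapper(*args, **kwargs):
        method(*args, **kwargs)  # result is computed, then silently dropped
    return wrapper

def fixed(method):
    def wrapper(*args, **kwargs):
        return method(*args, **kwargs)  # result is propagated to the caller
    return wrapper

@broken
def double_broken(x):
    return 2 * x

@fixed
def double_fixed(x):
    return 2 * x

assert double_broken(2) is None  # the return value is lost
assert double_fixed(2) == 4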


-class HTFBaseDatasetHandler(ABC):
+class HTFBaseDatasetHandler(ABC, ObjectLogger):
     def __init__(
         self,
         dataset: CastableTableDataset,
         config: "HTFDatasetConfig",
         global_config: Optional["HTFConfiguration"] = None,
+        verbose: bool = True,
         **kwargs,
     ) -> None:
         # Init Data
@@ -68,13 +70,27 @@ def __init__
         self._test_set: Optional[CastableTableDataset] = None
         self._train_set: Optional[CastableTableDataset] = None
         self._validation_set: Optional[CastableTableDataset] = None
+        self._test_dataset: Optional[tf.data.Dataset] = None
+        self._train_dataset: Optional[tf.data.Dataset] = None
+        self._validation_dataset: Optional[tf.data.Dataset] = None
         self._config = config
+        self._metadata = HTFBaseDatasetHandlerMetadata()
         self.global_config = global_config
+        self._verbose = True
         self.kwargs = kwargs
+        # Booleans
+        self.checked = self._check_config(_error=False)
+        self.splitted = False
+        # Launch Init
+        self.init_parameters()
         # Apply Checks
-        self.check_config()
+        self._check_config()
         # Apply Execution
         self.execute()
+
+    def init_parameters(self, **kwargs) -> None:
+        pass

     @property
     def metadata(self) -> HTFBaseDatasetHandlerMetadata:
         return self._metadata
@@ -87,14 +103,6 @@ def config(self) -> "HTFDatasetConfig":
     def split(self) -> "HTFDatasetSplitConfig":
         return self._config.split

-    @property
-    def dataset_is_numpy(self) -> bool:
-        return isinstance(self._data, np.ndarray)
-
-    @property
-    def dataset_is_pandas(self) -> bool:
-        return isinstance(self._data, pd.DataFrame)
-
     @property
     def sets(self) -> "HTFDatasetSetsConfig":
         return self.config.sets
@@ -123,6 +131,50 @@ def has_test(self) -> bool:
     def has_validation(self) -> bool:
         return self.sets.validation

+    @property
+    def CHECK_CONDITIONS(self) -> List[bool]:
+        return [
+            sum(
+                [self.config.split.train_ratio]
+                + [
+                    getattr(self.config.split, f"{s}_ratio")
+                    for s, m in self.config.sets.__fields__.items()
+                    if getattr(self.config.sets, s)
+                ]
+            )
+            == 1  # Check that activated set ratios sums to 1
+        ]
+
+    def _check_config(self, _error: bool = True):
+        validity = all(self.CHECK_CONDITIONS)
+        if not validity and _error:
+            raise BadConfigurationError("Dataset configuration is incorrect")
+        return validity
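
`CHECK_CONDITIONS` now drives validation, and `_check_config` only raises when the check actually fails (the old `check_config`, removed further down, raised whenever `_error` was set, even for valid configurations). The condition requires the train ratio plus the ratios of all activated sets to sum to 1. A sketch of the same check with plain dictionaries standing in for the pydantic config objects (field values are illustrative):

split = {"train_ratio": 0.5, "test_ratio": 0.25, "validation_ratio": 0.25}
sets = {"test": True, "validation": True}

active_ratios = [split["train_ratio"]] + [
    split[f"{name}_ratio"] for name, enabled in sets.items() if enabled
]
# 0.5 + 0.25 + 0.25 == 1.0 exactly; for ratios like 0.7/0.2/0.1, floating-point
# rounding makes a bare `== 1` fragile (math.isclose would be safer)
assert sum(active_ratios) == 1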

+    @abstractmethod
+    def execute(self) -> None:
+        raise NotImplementedError
+
+
+class HTFDatasetHandler(HTFBaseDatasetHandler):
+    def __init__(
+        self,
+        dataset: CastableTableDataset,
+        config: "HTFDatasetConfig",
+        global_config: Optional["HTFConfiguration"] = None,
+        verbose: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(dataset, config, global_config, verbose, **kwargs)
+
+    @property
+    def dataset_is_numpy(self) -> bool:
+        return isinstance(self._data, np.ndarray)
+
+    @property
+    def dataset_is_pandas(self) -> bool:
+        return isinstance(self._data, pd.DataFrame)
+
     @property
     def split_column_name(self) -> str:
         return self.split.column
@@ -147,22 +199,7 @@ def image_column_index(self) -> str:
     def label_mapper(self) -> Dict[str, int]:
         return self.global_config._metadata.label_mapper

-    def check_config(self, _error: bool = True):
-        conditions = [
-            sum(
-                [self.config.split.train_ratio]
-                + [
-                    getattr(self.config.split, f"{s}_ratio")
-                    for s, m in self.config.sets.__fields__.items()
-                    if getattr(self.config.sets, s)
-                ]
-            )
-            == 1  # Check that activated set ratios sums to 1
-        ]
-        validity = all(conditions)
-        if _error:
-            raise BadConfigurationError("Dataset configuration is incorrect")
-        return validity
+    # Split Train/Test/Validation Methods

     @global_config_required
     def _get_images(self) -> Set[str]:
@@ -195,11 +232,13 @@ def _generate_image_set(self) -> ImageSets:
             if has_test & has_validation:
                 # + Validation and Test
                 train, test = split_train_test(images, train_ratio + validation_ratio)
-                train, validation = split_train_test(train, train_ratio)
+                train, validation = split_train_test(
+                    train, train_ratio / (train_ratio + validation_ratio)
+                )
             elif has_validation:
-                train, validation = split_train_test(train, train_ratio)
+                train, validation = split_train_test(images, train_ratio)
             elif has_test:
-                train, test = split_train_test(train, train_ratio)
+                train, test = split_train_test(images, train_ratio)
             else:
                 train = images
         else:
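
The renormalization above is the substantive fix in this method. With ratios train=0.7, test=0.2, validation=0.1, the first `split_train_test` keeps 0.8 of the images as the train+validation pool; splitting that pool again at the raw `train_ratio` (the old code) would leave only 0.7 * 0.8 = 0.56 of all images for training, while `train_ratio / (train_ratio + validation_ratio)` = 0.875 restores the intended 0.7. A small arithmetic sketch, assuming `split_train_test(items, ratio)` keeps the first `ratio` fraction in its first returned set:

train_ratio, test_ratio, validation_ratio = 0.7, 0.2, 0.1
n_images = 1000

pool = round(n_images * (train_ratio + validation_ratio))  # 800 images for train+validation
train = round(pool * (train_ratio / (train_ratio + validation_ratio)))  # 0.875 * 800
assert train == 700  # 70% of all images, as configured
# Old behaviour: round(pool * train_ratio) == 560 images, not 700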
@@ -215,7 +254,7 @@ def _generate_set_from_pandas(self, image_set: Set[str]) -> pd.DataFrame:
         return self._data[self._data[self.image_column_name].isin(image_set)]

     def _generate_set_from_numpy(self, image_set: Set[str]):
-        indices = np.isin(self._data[:, self.image_column_index], image_set)
+        indices = np.isin(self._data[:, self.image_column_index], list(image_set))
         return self._data[indices]
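
The `list(image_set)` cast is the whole fix here: `np.isin` treats a Python `set` as a single scalar object rather than as a sequence of test elements, so matching against a raw set silently selects nothing. A quick demonstration:

import numpy as np

data = np.array(["img_001.jpg", "img_002.jpg", "img_003.jpg"])
image_set = {"img_001.jpg", "img_003.jpg"}

print(np.isin(data, image_set))        # [False False False] -- set treated as one object
print(np.isin(data, list(image_set)))  # [ True False  True] -- elements compared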

     @global_config_required
@@ -242,13 +281,13 @@ def _execute_split(self) -> ReturnTableSets:
     def _execute_selection(self) -> ReturnTableSets:
         if self.dataset_is_pandas:
             train = self._data.query(
-                f"{self.split_column_name} == {self.split.train_value}"
+                f"{self.split_column_name} == '{self.split.train_value}'"
             )
             test = self._data.query(
-                f"{self.split_column_name} == {self.split.test_value}"
+                f"{self.split_column_name} == '{self.split.test_value}'"
             )
             validation = self._data.query(
-                f"{self.split_column_name} == {self.split.validation_value}"
+                f"{self.split_column_name} == '{self.split.validation_value}'"
             )
         elif self.dataset_is_numpy:
             train_mask = (
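
Quoting the split values is the pandas-side fix: with a string-valued split column, an unquoted value such as TRAIN is parsed by `DataFrame.query` as a column name and raises `UndefinedVariableError`, while `'TRAIN'` is a string literal. An illustration with a made-up frame (column and values are hypothetical):

import pandas as pd

df = pd.DataFrame({"split": ["TRAIN", "TEST", "TRAIN"], "image": ["a.jpg", "b.jpg", "c.jpg"]})

df.query("split == 'TRAIN'")  # works: selects the rows for a.jpg and c.jpg
# df.query("split == TRAIN")  # raises UndefinedVariableError: name 'TRAIN' is not defined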
@@ -278,11 +317,59 @@ def split_sets(self) -> None:
         self._train_set = train
         self._test_set = test
         self._validation_set = validation
+        self.splitted = True
+
+    def _get_joint_columns(self) -> List[str]:
+        joints = self.global_config.config.data.output.joints
+        num_joint = joints.n
+        naming = joints.naming_convention
+        return [
+            naming.format(JOINT_ID=i, SUFFIX=suffix)
+            for i in range(num_joint)
+            for suffix in joints.format.suffix.__dict__.values()
+        ]
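
`_get_joint_columns` expands the configured naming template once per joint and per coordinate suffix. Assuming a convention like `joint_{JOINT_ID}_{SUFFIX}` with suffixes `x`, `y`, `visible` (illustrative values only, not necessarily this repo's defaults), the expansion behaves like:

naming = "joint_{JOINT_ID}_{SUFFIX}"  # hypothetical naming_convention
suffixes = ["x", "y", "visible"]      # hypothetical joints.format.suffix values
num_joints = 2

columns = [
    naming.format(JOINT_ID=i, SUFFIX=s)
    for i in range(num_joints)
    for s in suffixes
]
print(columns)
# ['joint_0_x', 'joint_0_y', 'joint_0_visible', 'joint_1_x', 'joint_1_y', 'joint_1_visible']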

+    def extract_data_groups(
+        self, dataset: CastableTableDataset
+    ) -> Tuple[List, List, List, List]:
+        # Extract columns
+        coord_columns = [self.label_mapper[col] for col in self._get_joint_columns()]
+        bbox_columns = [self.label_mapper[col] for col in self.config.bbox.cols]
+        center_columns = [self.label_mapper[col] for col in self.config.center.cols]
+        if self.dataset_is_pandas:
+            array = dataset.to_numpy()
+        else:
+            array = dataset
+        filenames = array[:, self.image_column_index].tolist()
+        coordinates = array[:, coord_columns].tolist()
+        bounding_boxes = array[:, bbox_columns].tolist()
+        centers = array[:, center_columns].tolist()
+        return filenames, coordinates, bounding_boxes, centers

+    # Generate Datasets
+
+    def _create_dataset(self, dataset: CastableTableDataset) -> tf.data.Dataset:
+        # Generate
+        return (
+            tf.data.Dataset.from_tensor_slices(
+                # Generate from slices
+                self.extract_data_groups(dataset=dataset)
+            )
+            .map(
+                # Load Images
+                lambda x: x
+            )
+            .map(
+                # Compute BBOX cropping
+                lambda x: x
+            )
+            .map(
+                # Compute Heatmaps
+                lambda x: x
+            )
+        )
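
`_create_dataset` zips the parallel lists returned by `extract_data_groups` into a `tf.data.Dataset`; the three `map` stages are still identity placeholders at this commit. A minimal runnable sketch of the same pipeline shape, with toy data in place of the real image-loading, cropping, and heatmap stages:

import tensorflow as tf

filenames = ["a.jpg", "b.jpg"]
coordinates = [[10.0, 20.0], [30.0, 40.0]]

dataset = tf.data.Dataset.from_tensor_slices((filenames, coordinates)).map(
    # Placeholder stage: a real one would e.g. read and decode each image file
    lambda filename, coords: (filename, coords)
)
for filename, coords in dataset:
    print(filename.numpy(), coords.numpy())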

     # Main Execution method
     def execute(self) -> None:
         self.split_sets()
-
-
-class HTFDatasetHandler(HTFBaseDatasetHandler):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        # pass
