diff --git a/docs/imgs/TransformerBlocks.png b/docs/imgs/TransformerBlocks.png
new file mode 100644
index 0000000..7a619ff
Binary files /dev/null and b/docs/imgs/TransformerBlocks.png differ
diff --git a/kdp/custom_layers.py b/kdp/custom_layers.py
index 2e2a0f1..a887fb3 100644
--- a/kdp/custom_layers.py
+++ b/kdp/custom_layers.py
@@ -81,3 +81,70 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
         """
         output = tf.cast(inputs, tf.float32)
         return output
+
+
+class TransformerBlock(tf.keras.layers.Layer):
+    """Class that implements a transformer block."""
+
+    def __init__(
+        self,
+        dim_model: int = 32,
+        num_heads: int = 3,
+        ff_units: int = 16,
+        dropout_rate: float = 0.2,
+        **kwargs,
+    ):
+        """Initializes the transformer block.
+
+        Args:
+            dim_model (int): Dimension of the model.
+            num_heads (int): Number of attention heads.
+            ff_units (int): Units in the feed-forward layer.
+            dropout_rate (float): Dropout rate to apply.
+            kwargs: Additional keyword arguments.
+        """
+        super().__init__(**kwargs)
+        self.d_model = dim_model
+        self.num_heads = num_heads
+        self.ff_units = ff_units
+        self.dropout_rate = dropout_rate
+
+        # Define layers
+        self.multihead_attention = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=dim_model)
+        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
+        self.add1 = tf.keras.layers.Add()
+        self.layer_norm1 = tf.keras.layers.LayerNormalization()
+
+        self.ff1 = tf.keras.layers.Dense(ff_units, activation="relu")
+        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
+        self.ff2 = tf.keras.layers.Dense(dim_model)
+        self.add2 = tf.keras.layers.Add()
+        self.layer_norm2 = tf.keras.layers.LayerNormalization()
+
+    def call(self, inputs: tf.Tensor) -> tf.Tensor:
+        """Defines the forward pass for the transformer block.
+
+        Args:
+            inputs (tf.Tensor): Input tensor for the block.
+
+        Returns:
+            tf.Tensor: Output tensor after processing.
+        """
+        # Reshape if needed
+        if len(inputs.shape) == 2:
+            inputs = tf.expand_dims(inputs, axis=1)
+
+        # Multi-head attention
+        attention = self.multihead_attention(inputs, inputs)
+        attention = self.dropout1(attention)
+        attention = self.add1([inputs, attention])
+        attention_norm = self.layer_norm1(attention)
+
+        # Feed-forward layers
+        ff = self.ff1(attention_norm)
+        ff = self.dropout2(ff)
+        ff = self.ff2(ff)
+        ff = self.add2([attention_norm, ff])
+        ff_norm = self.layer_norm2(ff)
+
+        return ff_norm
diff --git a/kdp/layers_factory.py b/kdp/layers_factory.py
index 3019368..984fd7c 100644
--- a/kdp/layers_factory.py
+++ b/kdp/layers_factory.py
@@ -2,7 +2,7 @@
 
 import tensorflow as tf
 
-from kdp.custom_layers import CastToFloat32Layer, TextPreprocessingLayer
+from kdp.custom_layers import CastToFloat32Layer, TextPreprocessingLayer, TransformerBlock
 
 
 class PreprocessorLayerFactory:
@@ -252,3 +252,20 @@ def cast_to_float32_layer(name: str = "cast_to_float32", **kwargs: dict) -> tf.k
             name=name,
             **kwargs,
         )
+
+    @staticmethod
+    def transformer_block_layer(name: str = "transformer", **kwargs: dict) -> tf.keras.layers.Layer:
+        """Create a TransformerBlock layer.
+
+        Args:
+            name: The name of the layer.
+            **kwargs: Additional keyword arguments to pass to the layer constructor.
+
+        Returns:
+            An instance of the TransformerBlock layer.
+ """ + return PreprocessorLayerFactory.create_layer( + layer_class=TransformerBlock, + name=name, + **kwargs, + ) diff --git a/kdp/processor.py b/kdp/processor.py index b38e970..82d8494 100644 --- a/kdp/processor.py +++ b/kdp/processor.py @@ -106,6 +106,10 @@ def __init__( overwrite_stats: bool = False, log_to_file: bool = False, features_specs: dict[str, FeatureType | str] = None, + transfo_nr_blocks: int = None, + transfo_nr_heads: int = 3, + transfo_ff_units: int = 16, + transfo_dropout_rate: float = 0.25, ) -> None: """Initialize a preprocessing model. @@ -121,6 +125,11 @@ def __init__( overwrite_stats (bool): A boolean indicating whether to overwrite the statistics. log_to_file (bool): A boolean indicating whether to log to a file. features_specs (dict[str, FeatureType | str]): A dictionary containing the features and their types. + transfo_nr_blocks (int): The number of transformer blocks for the transformer block + (default=None, transformer block is disabled). + transfo_nr_heads (int): The number of heads for the transformer block (categorical variables). + transfo_ff_units (int): The number of feed forward units for the transformer + transfo_dropout_rate (float): The dropout rate for the transformer block (default=0.25). """ self.path_data = path_data self.batch_size = batch_size or 50_000 @@ -130,12 +139,18 @@ def __init__( self.feature_crosses = feature_crosses or [] self.output_mode = output_mode self.overwrite_stats = overwrite_stats + # transformer blocks controll + self.transfo_nr_blocks = transfo_nr_blocks + self.transfo_nr_heads = transfo_nr_heads + self.transfo_ff_units = transfo_ff_units + self.transfo_dropout_rate = transfo_dropout_rate # PLACEHOLDERS self.preprocessors = {} self.inputs = {} self.signature = {} self.outputs = {} + self.outputs_categorical = {} if log_to_file: logger.info("Logging to file enabled 🗂️") @@ -322,11 +337,6 @@ def _add_pipeline_numeric(self, feature_name: str, input_layer, stats: dict) -> # defining the pipeline input layer _output_pipeline = preprocessor.chain(input_layer=input_layer) - # adjusting output - # if _feature.feature_type == FeatureType.FLOAT_DISCRETIZED: - # Cast the crossed feature to float32 - # _output_pipeline = tf.cast(_output_pipeline, tf.float32) - # defining output self.outputs[feature_name] = _output_pipeline @@ -402,8 +412,9 @@ def _add_pipeline_categorical(self, feature_name: str, input_layer, stats: dict) layer_creator=PreprocessorLayerFactory.flatten_layer, name=f"flatten_{feature_name}", ) + # adding outputs - self.outputs[feature_name] = preprocessor.chain(input_layer=input_layer) + self.outputs_categorical[feature_name] = preprocessor.chain(input_layer=input_layer) def _add_pipeline_text(self, feature_name: str, input_layer, stats: dict) -> None: """Add a text preprocessing step to the pipeline. @@ -455,7 +466,9 @@ def _add_pipeline_text(self, feature_name: str, input_layer, stats: dict) -> Non layer_creator=PreprocessorLayerFactory.cast_to_float32_layer, name=f"cast_to_float_{feature_name}", ) - self.outputs[feature_name] = preprocessor.chain(input_layer=input_layer) + # adding outputs + self.outputs_categorical[feature_name] = preprocessor.chain(input_layer=input_layer) + # self.outputs[feature_name] = preprocessor.chain(input_layer=input_layer) def _add_pipeline_cross(self) -> None: """Add a crossing preprocessing step to the pipeline. 
@@ -500,9 +513,41 @@ def _prepare_outputs(self) -> None:
         """
         logger.info("Building preprocessor Model")
         if self.output_mode == OutputModeOptions.CONCAT:
+            # getting all features to concatenate
             self.features_to_concat = list(self.outputs.values())
-            self.concat = tf.keras.layers.Concatenate(axis=-1)
-            self.outputs = self.concat(self.features_to_concat)
+            self.features_cat_to_concat = list(self.outputs_categorical.values())
+
+            # Concatenate numerical features
+            concat_num = tf.keras.layers.Concatenate(
+                name="ConcatenateNumeric",
+                axis=-1,
+            )(self.features_to_concat)
+
+            # Concatenate categorical features
+            concat_cat = tf.keras.layers.Concatenate(
+                name="ConcatenateCategorical",
+                axis=-1,
+            )(self.features_cat_to_concat)
+
+            # adding transformer layers
+            if self.transfo_nr_blocks:
+                logger.info(f"Adding transformer blocks: #{self.transfo_nr_blocks}")
+                for block_idx in range(self.transfo_nr_blocks):
+                    concat_cat = PreprocessorLayerFactory.transformer_block_layer(
+                        dim_model=concat_cat.shape[1],
+                        num_heads=self.transfo_nr_heads,
+                        ff_units=self.transfo_ff_units,
+                        dropout_rate=self.transfo_dropout_rate,
+                        name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
+                    )(concat_cat)
+
+            # Combine concatenated numerical and categorical features
+            self.outputs = tf.keras.layers.Concatenate(
+                name="ConcatenateAllFeatures",
+                axis=-1,
+            )([concat_num, concat_cat])
+
+            # self.outputs = self.concat(self.features_to_concat + [self.concat])
             logger.info("Concatenating outputs mode enabled")
         else:
             outputs = OrderedDict([(k, None) for k in self.inputs if k in self.outputs])
diff --git a/pyproject.toml b/pyproject.toml
index f4aaf6b..1977bcb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "kdp"
-version = "1.4.0"
+version = "1.5.1"
 documentation = "http://piotrlaczkowski.github.io/keras-data-processor/"
 repository = "https://github.com/piotrlaczkowski/keras-data-processor"
 description = "Data Preprocessing model based on Keras preprocessing layers"
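
Usage sketch (illustrative only, not part of the patch): the snippet below exercises the new TransformerBlock layer on its own and shows how the new transfo_* constructor arguments would be passed to the preprocessor. Only the TransformerBlock signature and the transfo_* keyword names come from this diff; the PreprocessingModel class name, the data path, and the features_specs placeholder are assumptions made for illustration.

import tensorflow as tf

from kdp.custom_layers import TransformerBlock

# Standalone check of the new layer. dim_model should equal the input's last
# dimension: the second residual Add() sums a Dense(dim_model) projection with
# the (expanded) input, so the shapes must match.
block = TransformerBlock(dim_model=8, num_heads=2, ff_units=16, dropout_rate=0.1)
out = block(tf.random.uniform((4, 8)))  # 2-D input is expanded to (batch, 1, features)
print(out.shape)  # -> (4, 1, 8)

# Enabling the stacked blocks on the concatenated categorical features
# (names other than the transfo_* arguments are assumed for illustration):
# from kdp.processor import PreprocessingModel
# ppr = PreprocessingModel(
#     path_data="data/train.csv",        # placeholder path
#     features_specs=my_features_specs,  # hypothetical feature-spec dict
#     transfo_nr_blocks=2,               # None (the default) disables the blocks
#     transfo_nr_heads=3,
#     transfo_ff_units=16,
#     transfo_dropout_rate=0.25,
# )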