
Commit

feat(KDP): adding TransformerBlocks to all Features Options
piotrlaczkowski committed Apr 30, 2024
1 parent 0338fb4 commit e27b0e5
Showing 3 changed files with 26 additions and 3 deletions.
Binary file added docs/imgs/TransformerBlockAllFeatures.png
File renamed without changes
29 changes: 26 additions & 3 deletions kdp/processor.py
@@ -30,6 +30,11 @@ class TextVectorizerOutputOptions(auto):
MULTI_HOT = "multi_hot"


class TransformerBlockPlacementOptions(auto):
CATEGORICAL = "categorical"
ALL_FEATURES = "all_features"


# testing conversion
class FeatureSpaceConverter:
def __init__(self):
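Note the convention used throughout this module: the options classes subclass `auto` and expose plain string attributes, so the new placement constants compare directly against the `transfo_placement` string passed to the constructor. A minimal check (not part of the commit, assuming only what the hunk above shows):

from kdp.processor import TransformerBlockPlacementOptions

# the constants are plain strings, so the literal spelling works too
assert TransformerBlockPlacementOptions.CATEGORICAL == "categorical"
assert TransformerBlockPlacementOptions.ALL_FEATURES == "all_features"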
@@ -110,6 +115,7 @@ def __init__(
transfo_nr_heads: int = 3,
transfo_ff_units: int = 16,
transfo_dropout_rate: float = 0.25,
transfo_placement: str = TransformerBlockPlacementOptions.CATEGORICAL,
) -> None:
"""Initialize a preprocessing model.
@@ -130,6 +136,7 @@ def __init__(
transfo_nr_heads (int): The number of heads for the transformer block (categorical variables).
transfo_ff_units (int): The number of feed-forward units for the transformer block (default=16).
transfo_dropout_rate (float): The dropout rate for the transformer block (default=0.25).
transfo_placement (str): The placement of the transformer block (categorical | all_features).
"""
self.path_data = path_data
self.batch_size = batch_size or 50_000
@@ -144,6 +151,7 @@ def __init__(
self.transfo_nr_heads = transfo_nr_heads
self.transfo_ff_units = transfo_ff_units
self.transfo_dropout_rate = transfo_dropout_rate
self.transfo_placement = transfo_placement

# PLACEHOLDERS
self.preprocessors = {}
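For orientation, a hedged usage sketch of the new parameter. Only `path_data` and the `transfo_*` keyword arguments appear in the hunks above; the enclosing class name `PreprocessingModel` and the `features_specs` argument are assumptions about surrounding code not shown in this diff:

from kdp.processor import PreprocessingModel, TransformerBlockPlacementOptions

ppr = PreprocessingModel(
    path_data="data/train.csv",  # assumed example path
    features_specs=features_specs,  # assumed: feature definitions built elsewhere
    transfo_nr_blocks=2,  # enables the transformer blocks
    transfo_nr_heads=3,
    transfo_ff_units=16,
    transfo_dropout_rate=0.25,
    # apply the blocks to the full concatenated output
    # rather than to categorical features only
    transfo_placement=TransformerBlockPlacementOptions.ALL_FEATURES,
)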
@@ -530,8 +538,8 @@ def _prepare_outputs(self) -> None:
)(self.features_cat_to_concat)

# adding transformer layers
- if self.transfo_nr_blocks:
- logger.info(f"Adding transformer blocks: #{self.transfo_nr_blocks}")
+ if self.transfo_nr_blocks and self.transfo_placement == TransformerBlockPlacementOptions.CATEGORICAL:
+ logger.info(f"Adding transformer blocks CATEGORICAL: #{self.transfo_nr_blocks}")
for block_idx in range(self.transfo_nr_blocks):
concat_cat = PreprocessorLayerFactory.transformer_block_layer(
dim_model=concat_cat.shape[1],
@@ -542,12 +550,27 @@ def _prepare_outputs(self) -> None:
)(concat_cat)

# Combine concatenated numerical and categorical features
logger.info("Concatenating all features")
self.outputs = tf.keras.layers.Concatenate(
name="ConcatenateAllFeatures",
axis=-1,
)([concat_num, concat_cat])

# self.outputs = self.concat(self.features_to_concat + [self.concat])
# TODO: check shape mismatch here (Inputs have incompatible shapes. Received shapes (1, 141) and (1, 4))
if self.transfo_nr_blocks and self.transfo_placement == TransformerBlockPlacementOptions.ALL_FEATURES:
_transformer_input_shape = self.outputs.shape[1]
logger.info(
f"Adding transformer blocks ALL_FEATURES: #{self.transfo_nr_blocks}",
)
for block_idx in range(self.transfo_nr_blocks):
self.outputs = PreprocessorLayerFactory.transformer_block_layer(
dim_model=_transformer_input_shape,
num_heads=self.transfo_nr_heads,
ff_units=self.transfo_ff_units,
dropout_rate=self.transfo_dropout_rate,
name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
)(self.outputs)

logger.info("Concatenating outputs mode enabled")
else:
outputs = OrderedDict([(k, None) for k in self.inputs if k in self.outputs])
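`PreprocessorLayerFactory.transformer_block_layer` itself is not part of this diff. As a rough sketch only, a comparable Keras block matching the call signature used above (dim_model, num_heads, ff_units, dropout_rate, name) could look like the following; the actual factory in kdp may differ:

import tensorflow as tf

def transformer_block_layer(dim_model, num_heads, ff_units, dropout_rate, name=None):
    # The diff feeds a flat (batch, dim_model) tensor, so this sketch
    # treats it as a single-token sequence for self-attention.
    inputs = tf.keras.Input(shape=(dim_model,))
    x = tf.keras.layers.Reshape((1, dim_model))(inputs)
    # self-attention with a residual connection and layer norm
    attn = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=dim_model)(x, x)
    x = tf.keras.layers.LayerNormalization()(x + attn)
    # position-wise feed-forward network, projected back to dim_model
    ff = tf.keras.layers.Dense(ff_units, activation="relu")(x)
    ff = tf.keras.layers.Dropout(dropout_rate)(ff)
    ff = tf.keras.layers.Dense(dim_model)(ff)
    x = tf.keras.layers.LayerNormalization()(x + ff)
    outputs = tf.keras.layers.Reshape((dim_model,))(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)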
