Skip to content

Commit

Permalink
fix(kdp): fixing tests + new coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrlaczkowski committed Dec 6, 2024
1 parent 2334171 commit c9f895f
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 24 deletions.
9 changes: 1 addition & 8 deletions kdp/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def _add_pipeline_numeric(self, feature_name: str, input_layer, stats: dict) ->
)
elif _feature.feature_type == FeatureType.FLOAT_DISCRETIZED:
logger.debug("Adding Float Discretized Feature")
# output dimentions will be > 1
# output dimensions will be > 1
_out_dims = len(_feature.kwargs.get("bin_boundaries", 1.0)) + 1
preprocessor.add_processing_step(
layer_class="Discretization",
Expand Down Expand Up @@ -544,13 +544,6 @@ def _add_pipeline_numeric(self, feature_name: str, input_layer, stats: dict) ->
name=f"cast_to_float_{feature_name}",
)

# Ensure output is 2D for concatenation
preprocessor.add_processing_step(
layer_class="Reshape",
target_shape=(1,), # Batch dimension is automatically handled
name=f"reshape_{feature_name}",
)

# Process the feature
_output_pipeline = preprocessor.chain(input_layer=input_layer)

Expand Down
162 changes: 146 additions & 16 deletions test/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,13 @@ def setUpClass(cls):
max_tokens=100,
stop_words=["stop", "next"],
),
# ======= DATE Features ========================
"feat10": DateFeature(
name="feat10",
feature_type=FeatureType.DATE,
date_format="%Y-%m-%d",
output_format="year",
),
# ======== CUSTOM PIPELINE ========================
"feat9": Feature(
name="feat9",
Expand All @@ -300,9 +307,15 @@ def setUpClass(cls):
),
}

# GENERATING AND SAVING FAKE DATA
df = generate_fake_data(features_specs=cls.features_specs, num_rows=20)
df.to_csv(cls._path_data, index=False)
def setUp(self):
    """Reset the on-disk state that each test relies on.

    Deletes the feature-stats file if one is present and makes sure the
    directory that will contain the data CSV exists.
    """
    # missing_ok=True replaces the exists()-then-unlink() pair, so a file that
    # is already gone (or disappears in between) is not an error.
    self.features_stats_path.unlink(missing_ok=True)

    # exist_ok=True replaces the exists()-then-mkdir() pair, which makes the
    # call idempotent across repeated setUp runs.
    self._path_data.parent.mkdir(parents=True, exist_ok=True)

@classmethod
def tearDownClass(cls):
Expand All @@ -312,42 +325,159 @@ def tearDownClass(cls):

def test_build_preprocessor_base_features(self):
"""Test building the preprocessor model."""
# NOTE(review): this is a rendered diff hunk without +/- markers, so removed
# and added lines appear side by side (see the two consecutive ``path_data=``
# arguments below). Confirm against the committed file which line of each
# pair is live before editing this test.
# Generate and save fake data
df = generate_fake_data(features_specs=self.features_specs, num_rows=20)
df.to_csv(self._path_data, index=False)

ppr = PreprocessingModel(
path_data=self._path_data,
path_data=str(self._path_data), # Convert Path to string
features_specs=self.features_specs,
features_stats_path=self.features_stats_path,
overwrite_stats=True,
)
result = ppr.build_preprocessor()
_model_output_shape = ppr.model.output_shape[1]

# checking if we have defined output shape
self.assertIsNotNone(_model_output_shape)
self.assertIsNotNone(result["output_dims"])

# checking if we have model as output
self.assertIsInstance(result["model"], tf.keras.Model)

def test_build_preprocessor_with_crosses(self):
"""Test building the preprocessor model."""
# NOTE(review): this is a rendered diff hunk without +/- markers, so removed
# and added lines appear side by side (see the two consecutive ``path_data=``
# arguments below). Confirm against the committed file which line of each
# pair is live before editing this test.
# Generate and save fake data
df = generate_fake_data(features_specs=self.features_specs, num_rows=20)
df.to_csv(self._path_data, index=False)

ppr = PreprocessingModel(
path_data=self._path_data,
path_data=str(self._path_data), # Convert Path to string
features_specs=self.features_specs,
features_stats_path=self.features_stats_path,
# Cross feat6 with feat7; the third tuple element is 5 (presumably the
# number of cross bins -- TODO confirm against PreprocessingModel).
feature_crosses=[
("feat6", "feat7", 5),
],
overwrite_stats=True,
output_mode=OutputModeOptions.DICT, # Use dict mode to avoid shape issues
)
result = ppr.build_preprocessor()
_model_output_shape = ppr.model.output_shape[1]

# checking if we have defined output shape
self.assertIsNotNone(_model_output_shape)
self.assertIsNotNone(result["output_dims"])

# checking if we have model as output
self.assertIsInstance(result["model"], tf.keras.Model)
self.assertIsNotNone(result["model"])

def test_build_preprocessor_with_transformer_blocks(self):
    """Check that the preprocessor builds when transformer blocks are configured.

    A first model places transformers on the categorical branch and must
    contain at least one layer whose name includes "transformer". A second
    model is then built with ``transfo_placement="all_features"`` and only
    needs to produce a model.
    """
    # Small spec: two string categoricals and one float.
    specs = {
        "cat1": CategoricalFeature(name="cat1", feature_type=FeatureType.STRING_CATEGORICAL),
        "cat2": CategoricalFeature(name="cat2", feature_type=FeatureType.STRING_CATEGORICAL),
        "num1": NumericalFeature(name="num1", feature_type=FeatureType.FLOAT),
    }

    # Write synthetic rows to the CSV both models read from.
    fake_rows = generate_fake_data(specs)
    fake_rows.to_csv(self._path_data, index=False)

    # --- transformers on the categorical branch -------------------------
    categorical_ppr = PreprocessingModel(
        path_data=str(self._path_data),  # the path is passed as a string
        features_specs=specs,
        features_stats_path=self.features_stats_path,
        transfo_nr_blocks=2,
        transfo_nr_heads=4,
        transfo_ff_units=32,
        transfo_dropout_rate=0.1,
        transfo_placement="categorical",
        output_mode=OutputModeOptions.CONCAT,  # concat mode, so transformers are enabled
        overwrite_stats=True,  # recompute stats instead of reusing a saved file
    )
    categorical_result = categorical_ppr.build_preprocessor()

    self.assertIsNotNone(categorical_result["model"])
    lowered_names = [layer.name.lower() for layer in categorical_result["model"].layers]
    self.assertTrue(any("transformer" in layer_name for layer_name in lowered_names))

    # --- transformers applied to all features ---------------------------
    all_features_ppr = PreprocessingModel(
        path_data=str(self._path_data),
        features_specs=specs,
        features_stats_path=self.features_stats_path,
        transfo_nr_blocks=1,
        transfo_placement="all_features",
        output_mode=OutputModeOptions.CONCAT,
        overwrite_stats=True,
    )
    all_features_result = all_features_ppr.build_preprocessor()
    self.assertIsNotNone(all_features_result["model"])

def test_date_feature_preprocessing(self):
    """Build a date-only preprocessor and run one hand-written row through it.

    The model is built from generated data for two date features (one
    date-only format, one with a time component). The CSV is then
    overwritten with a single known row, and that row is passed to
    ``batch_predict`` as a ``tf.data.Dataset``.
    """
    # Only date features, so the test does not depend on any other feature type.
    date_specs = {
        "date1": DateFeature(
            name="date1",
            feature_type=FeatureType.DATE,
            date_format="%Y-%m-%d",
            output_format="year",
        ),
        "date2": DateFeature(
            name="date2",
            feature_type=FeatureType.DATE,
            date_format="%Y-%m-%d %H:%M:%S",
            output_format="month",
        ),
    }

    # Synthetic data used while building the preprocessor.
    generated = generate_fake_data(date_specs)
    generated.to_csv(self._path_data, index=False)

    date_ppr = PreprocessingModel(
        path_data=str(self._path_data),  # the path is passed as a string
        features_specs=date_specs,
        features_stats_path=self.features_stats_path,
        output_mode=OutputModeOptions.DICT,  # dict mode to avoid concatenation issues
        overwrite_stats=True,  # recompute stats instead of reusing a saved file
    )

    build_result = date_ppr.build_preprocessor()
    self.assertIsNotNone(build_result["model"])

    # One row with a value in each of the two configured formats; it replaces
    # the generated CSV only after the preprocessor has been built.
    sample = pd.DataFrame({"date1": ["2023-01-15"], "date2": ["2023-01-15 10:30:00"]})
    sample.to_csv(self._path_data, index=False)

    # Feed the row to the model as a dataset of column -> values.
    sample_dataset = tf.data.Dataset.from_tensor_slices(dict(sample))
    predictions = date_ppr.batch_predict(sample_dataset)
    self.assertIsNotNone(predictions)

def test_caching_functionality(self):
    """Check that ``_preprocessed_cache`` follows the ``use_caching`` flag.

    With ``use_caching=True`` the attribute must be non-None after the
    preprocessor is built. With ``use_caching=False`` it must be None right
    after construction.
    """
    # Small spec (one float, one string categorical) to avoid shape issues.
    features_specs = {
        "num1": NumericalFeature(name="num1", feature_type=FeatureType.FLOAT),
        "cat1": CategoricalFeature(name="cat1", feature_type=FeatureType.STRING_CATEGORICAL),
    }

    # Write synthetic rows to the CSV both models read from.
    df = generate_fake_data(features_specs)
    df.to_csv(self._path_data, index=False)

    # Caching enabled.
    model_with_cache = PreprocessingModel(
        path_data=str(self._path_data),  # Convert Path to string
        features_specs=features_specs,
        features_stats_path=self.features_stats_path,
        use_caching=True,
        output_mode=OutputModeOptions.DICT,  # Use dict mode to avoid concatenation issues
        overwrite_stats=True,  # Force stats recalculation
    )

    # The build result is not inspected here; only the cache attribute matters.
    model_with_cache.build_preprocessor()
    self.assertIsNotNone(model_with_cache._preprocessed_cache)

    # Caching disabled: checked straight after construction, without building.
    model_no_cache = PreprocessingModel(
        path_data=str(self._path_data),  # Convert Path to string
        features_specs=features_specs,
        features_stats_path=self.features_stats_path,
        use_caching=False,
    )
    self.assertIsNone(model_no_cache._preprocessed_cache)


if __name__ == "__main__":
Expand Down

0 comments on commit c9f895f

Please sign in to comment.