diff --git a/tests/popmon/base/test_module.py b/tests/popmon/base/test_module.py index c5322b73..abd91728 100644 --- a/tests/popmon/base/test_module.py +++ b/tests/popmon/base/test_module.py @@ -1,32 +1,41 @@ import numpy as np +import pytest from popmon.base import Module -def test_popmon_module(): - class Scaler(Module): - _input_keys = ("input_key",) - _output_keys = ("output_key",) +class Scaler(Module): + _input_keys = ("input_key",) + _output_keys = ("output_key",) - def __init__(self, input_key, output_key, mean, std): - super().__init__() - self.input_key = input_key - self.output_key = output_key - self.mean = mean - self.std = std + def __init__(self, input_key, output_key, mean, std): + super().__init__() + self.input_key = input_key + self.output_key = output_key + self.mean = mean + self.std = std - def transform(self, input_array: np.ndarray): - res = input_array - np.mean(input_array) - res = res / np.std(res) - res = res * self.std - res = res + self.mean - return res + def transform(self, input_array: np.ndarray): + res = input_array - np.mean(input_array) + res = res / np.std(res) + res = res * self.std + res = res + self.mean + return res - test_module = Scaler(input_key="x", output_key="scaled_x", mean=2.0, std=0.3) +@pytest.fixture +def test_module(): + return Scaler(input_key="x", output_key="scaled_x", mean=2.0, std=0.3) + + +def test_popmon_module(test_module): datastore = {"x": np.arange(10)} datastore = test_module.transform(datastore) assert "x" in datastore # check if key 'x' is still in the datastore np.testing.assert_almost_equal(np.mean(datastore["scaled_x"]), 2.0, decimal=5) np.testing.assert_almost_equal(np.std(datastore["scaled_x"]), 0.3, decimal=5) + + +def test_popmon_module_repr(test_module): + assert str(test_module) == "Scaler(input_key='x', output_key='scaled_x')" diff --git a/tests/popmon/base/test_pipeline.py b/tests/popmon/base/test_pipeline.py index 79c22908..12afc673 100644 --- a/tests/popmon/base/test_pipeline.py +++ b/tests/popmon/base/test_pipeline.py @@ -1,6 +1,7 @@ import logging import numpy as np +import pytest from popmon.base import Module, Pipeline @@ -65,16 +66,12 @@ def transform(self, input_array: np.ndarray, weights: np.ndarray): return result -def test_popmon_pipeline(): +@pytest.fixture +def test_pipeline(): logger = logging.getLogger() logger.addHandler(logging.StreamHandler()) logger.setLevel(logging.INFO) - datastore = {"x": np.array([7, 2, 7, 9, 6]), "weights": np.array([1, 1, 2, 1, 2])} - expected_result = np.sum( - np.power(np.log(datastore["x"]), 2) * datastore["weights"] - ) / np.sum(datastore["weights"]) - log_pow_pipeline = Pipeline( modules=[ LogTransformer(input_key="x", output_key="log_x"), @@ -92,5 +89,20 @@ def test_popmon_pipeline(): ], logger=logger, ) + return pipeline + + +def test_popmon_pipeline(test_pipeline): + datastore = {"x": np.array([7, 2, 7, 9, 6]), "weights": np.array([1, 1, 2, 1, 2])} + expected_result = np.sum( + np.power(np.log(datastore["x"]), 2) * datastore["weights"] + ) / np.sum(datastore["weights"]) - assert pipeline.transform(datastore)["res"] == expected_result + assert test_pipeline.transform(datastore)["res"] == expected_result + + +def test_pipeline_repr(test_pipeline): + assert ( + str(test_pipeline) + == """Pipeline: [\n\tPipeline: [\n\t\tLogTransformer(input_key='x', output_key='log_x')\n\t\tPowerTransformer(input_key='log_x', output_key='log_pow_x')\n\t]\n\tSumNormalizer(input_key='weights', output_key='norm_weights')\n\tWeightedSum(input_key='log_pow_x', weight_key='norm_weights', output_key='res')\n]""" + )