Skip to content

Commit

Permalink
test: string representation base classes
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrugman committed Jun 28, 2022
1 parent 5e369ff commit 1532239
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
43 changes: 26 additions & 17 deletions tests/popmon/base/test_module.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,41 @@
import numpy as np
import pytest

from popmon.base import Module


def test_popmon_module():
class Scaler(Module):
_input_keys = ("input_key",)
_output_keys = ("output_key",)
class Scaler(Module):
_input_keys = ("input_key",)
_output_keys = ("output_key",)

def __init__(self, input_key, output_key, mean, std):
super().__init__()
self.input_key = input_key
self.output_key = output_key
self.mean = mean
self.std = std
def __init__(self, input_key, output_key, mean, std):
super().__init__()
self.input_key = input_key
self.output_key = output_key
self.mean = mean
self.std = std

def transform(self, input_array: np.ndarray):
res = input_array - np.mean(input_array)
res = res / np.std(res)
res = res * self.std
res = res + self.mean
return res
def transform(self, input_array: np.ndarray):
res = input_array - np.mean(input_array)
res = res / np.std(res)
res = res * self.std
res = res + self.mean
return res

test_module = Scaler(input_key="x", output_key="scaled_x", mean=2.0, std=0.3)

@pytest.fixture
def test_module():
return Scaler(input_key="x", output_key="scaled_x", mean=2.0, std=0.3)


def test_popmon_module(test_module):
datastore = {"x": np.arange(10)}
datastore = test_module.transform(datastore)

assert "x" in datastore # check if key 'x' is still in the datastore
np.testing.assert_almost_equal(np.mean(datastore["scaled_x"]), 2.0, decimal=5)
np.testing.assert_almost_equal(np.std(datastore["scaled_x"]), 0.3, decimal=5)


def test_popmon_module_repr(test_module):
assert str(test_module) == "Scaler(input_key='x', output_key='scaled_x')"
26 changes: 19 additions & 7 deletions tests/popmon/base/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

import numpy as np
import pytest

from popmon.base import Module, Pipeline

Expand Down Expand Up @@ -65,16 +66,12 @@ def transform(self, input_array: np.ndarray, weights: np.ndarray):
return result


def test_popmon_pipeline():
@pytest.fixture
def test_pipeline():
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)

datastore = {"x": np.array([7, 2, 7, 9, 6]), "weights": np.array([1, 1, 2, 1, 2])}
expected_result = np.sum(
np.power(np.log(datastore["x"]), 2) * datastore["weights"]
) / np.sum(datastore["weights"])

log_pow_pipeline = Pipeline(
modules=[
LogTransformer(input_key="x", output_key="log_x"),
Expand All @@ -92,5 +89,20 @@ def test_popmon_pipeline():
],
logger=logger,
)
return pipeline


def test_popmon_pipeline(test_pipeline):
datastore = {"x": np.array([7, 2, 7, 9, 6]), "weights": np.array([1, 1, 2, 1, 2])}
expected_result = np.sum(
np.power(np.log(datastore["x"]), 2) * datastore["weights"]
) / np.sum(datastore["weights"])

assert pipeline.transform(datastore)["res"] == expected_result
assert test_pipeline.transform(datastore)["res"] == expected_result


def test_pipeline_repr(test_pipeline):
assert (
str(test_pipeline)
== """Pipeline: [\n\tPipeline: [\n\t\tLogTransformer(input_key='x', output_key='log_x')\n\t\tPowerTransformer(input_key='log_x', output_key='log_pow_x')\n\t]\n\tSumNormalizer(input_key='weights', output_key='norm_weights')\n\tWeightedSum(input_key='log_pow_x', weight_key='norm_weights', output_key='res')\n]"""
)

0 comments on commit 1532239

Please sign in to comment.