Skip to content

Commit

Permalink
Christians comments, logging test
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 20, 2024
1 parent 8d1abd1 commit d5705e4
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 40 deletions.
127 changes: 101 additions & 26 deletions molpipeline/mol2any/mol2concatinated_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import numpy.typing as npt
from sklearn.base import clone

from loguru import logger

from molpipeline.abstract_pipeline_elements.core import (
InvalidInstance,
MolToAnyPipelineElement,
Expand All @@ -32,7 +34,7 @@ class MolToConcatenatedVector(MolToAnyPipelineElement):
def __init__(
self,
element_list: list[tuple[str, MolToAnyPipelineElement]],
feature_names_prefix: Optional[str] = None,
use_feature_names_prefix: bool = True,
name: str = "MolToConcatenatedVector",
n_jobs: int = 1,
uuid: Optional[str] = None,
Expand All @@ -44,8 +46,10 @@ def __init__(
----------
element_list: list[MolToAnyPipelineElement]
List of Pipeline Elements of which the output is concatenated.
feature_names_prefix: str, optional (default=None)
Prefix for feature names. If None, the name of the pipeline element is used.
use_feature_names_prefix: bool, optional (default=True)
If True, will add the pipeline element's name as prefix to feature names.
If False, only the feature names are used. This can lead to duplicate
feature names.
name: str, optional (default="MolToConcatenatedVector")
name of pipeline.
n_jobs: int, optional (default=1)
Expand All @@ -58,18 +62,13 @@ def __init__(
self._element_list = element_list
if len(element_list) == 0:
raise ValueError("element_list must contain at least one element.")
self._feature_names_prefix = feature_names_prefix
self._use_feature_names_prefix = use_feature_names_prefix
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
output_types = set()
for _, element in self._element_list:
element.n_jobs = self.n_jobs
output_types.add(element.output_type)
if len(output_types) == 1:
self._output_type = output_types.pop()
else:
self._output_type = "mixed"
self._requires_fitting = any(
element[1]._requires_fitting for element in element_list
# set element execution details
self._set_element_execution_details(self._element_list)
# set feature names
self._feature_names = self._create_feature_names(
self._element_list, self._use_feature_names_prefix
)
self.set_params(**kwargs)

Expand All @@ -96,25 +95,79 @@ def n_features(self) -> int:
@property
def feature_names(self) -> list[str]:
    """Return a copy of the stored feature names of the concatenated elements.

    Returns
    -------
    list[str]
        Copy of the names kept in ``self._feature_names``. A fresh object is
        returned on every access, so modifying the result does not change the
        names stored on this object.
    """
    names_copy = self._feature_names[:]
    return names_copy

@staticmethod
def _create_feature_names(
    element_list: list[tuple[str, MolToAnyPipelineElement]],
    use_feature_names_prefix: bool,
) -> list[str]:
    """Create feature names for concatenated vector from its elements.

    Parameters
    ----------
    element_list: list[tuple[str, MolToAnyPipelineElement]]
        List of (name, pipeline element) pairs. The names are collected in
        the order in which the elements are listed.
    use_feature_names_prefix: bool
        If True, will add the pipeline element's name as prefix to feature names,
        joined by a double underscore ("<name>__<feature>").
        If False, only the feature names are used. This can lead to duplicate
        feature names.

    Raises
    ------
    ValueError
        If element does not have feature_names attribute.

    Returns
    -------
    list[str]
        List of feature names.
    """
    feature_names = []
    for name, element in element_list:
        # Fail early: without feature_names the element cannot contribute.
        if not hasattr(element, "feature_names"):
            raise ValueError(
                f"Element {element} does not have feature_names attribute."
            )

        if use_feature_names_prefix:
            # use element name as prefix
            feature_names.extend(
                [f"{name}__{feature}" for feature in element.feature_names]  # type: ignore[attr-defined]
            )
        else:
            feature_names.extend(element.feature_names)  # type: ignore[attr-defined]

    # Duplicates are reported as a warning only; the names are still returned.
    if len(feature_names) != len(set(feature_names)):
        logger.warning(
            "Feature names in MolToConcatenatedVector are not unique."
            " Set use_feature_names_prefix=True and use unique pipeline element"
            " names to avoid this."
        )
    return feature_names

def _set_element_execution_details(
    self, element_list: list[tuple[str, MolToAnyPipelineElement]]
) -> None:
    """Set output type and requires fitting for the concatenated vector.

    Every element in ``element_list`` is also assigned this object's
    ``n_jobs`` value.

    Parameters
    ----------
    element_list: list[tuple[str, MolToAnyPipelineElement]]
        List of (name, pipeline element) pairs. Both the output type and the
        fitting requirement are derived from this list.
    """
    output_types = set()
    # Use the passed-in list consistently, for this loop as well as for the
    # fitting check below, instead of mixing it with self._element_list.
    for _, element in element_list:
        element.n_jobs = self.n_jobs
        output_types.add(element.output_type)

    # One shared output type is reported as-is; anything else is "mixed".
    if len(output_types) == 1:
        self._output_type = output_types.pop()
    else:
        self._output_type = "mixed"

    # Fitting is required as soon as a single element requires it.
    self._requires_fitting = any(
        element._requires_fitting for _, element in element_list
    )

def get_params(self, deep: bool = True) -> dict[str, Any]:
"""Return all parameters defining the object.
Expand All @@ -133,8 +186,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
parameters["element_list"] = [
(str(name), clone(ele)) for name, ele in self.element_list
]
parameters["use_feature_names_prefix"] = bool(
self._use_feature_names_prefix
)
else:
parameters["element_list"] = self.element_list
parameters["use_feature_names_prefix"] = self._use_feature_names_prefix
for name, element in self.element_list:
for key, value in element.get_params(deep=deep).items():
parameters[f"{name}__{key}"] = value
Expand All @@ -155,9 +212,15 @@ def set_params(self, **parameters: Any) -> Self:
Mol2ConcatenatedVector object with updated parameters.
"""
parameter_copy = dict(parameters)

# handle element_list
element_list = parameter_copy.pop("element_list", None)
if element_list is not None:
self._element_list = element_list
if len(element_list) == 0:
raise ValueError("element_list must contain at least one element.")
# reset element execution details
self._set_element_execution_details(self._element_list)
step_params: dict[str, dict[str, Any]] = {}
step_dict = dict(self._element_list)
to_delete_list = []
Expand All @@ -178,6 +241,18 @@ def set_params(self, **parameters: Any) -> Self:
_ = parameter_copy.pop(to_delete, None)
for step, params in step_params.items():
step_dict[step].set_params(**params)

# handle use_feature_names_prefix
use_feature_names_prefix = parameter_copy.pop("use_feature_names_prefix", None)
if use_feature_names_prefix is not None:
self._use_feature_names_prefix = use_feature_names_prefix
# reset feature names
self._feature_names = self._create_feature_names(
self._element_list,
self._use_feature_names_prefix, # type: ignore[arg-type]
)

# set parameters of super
super().set_params(**parameter_copy)
return self

Expand Down
135 changes: 121 additions & 14 deletions tests/test_elements/test_mol2any/test_mol2concatenated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MolToRDKitPhysChem,
)
from tests.utils.fingerprints import fingerprints_to_numpy
from tests.utils.logging import capture_logs


class TestConcatenatedFingerprint(unittest.TestCase):
Expand Down Expand Up @@ -94,9 +95,22 @@ def test_generation(self) -> None:

def test_empty_element_list(self) -> None:
    """Test that an empty element list is rejected by constructor and setter."""
    # the constructor must raise for an empty list
    self.assertRaises(ValueError, MolToConcatenatedVector, [])

    # a validly constructed object must raise when set_params gets an empty list
    single_element = [("RDKitPhysChem", MolToRDKitPhysChem())]
    concat_elem = MolToConcatenatedVector(single_element)
    self.assertRaises(ValueError, concat_elem.set_params, element_list=[])

def test_n_features(self) -> None:
"""Test getting the number of features in the concatenated vector."""

Expand Down Expand Up @@ -147,7 +161,7 @@ def test_features_names(self) -> None: # pylint: disable-msg=too-many-locals
net_charge_elem = ("NetCharge", MolToNetCharge())
morgan_elem = (
"MorganFP",
MolToMorganFP(n_bits=16),
MolToMorganFP(n_bits=17),
)
path_elem = (
"PathFP",
Expand All @@ -158,9 +172,15 @@ def test_features_names(self) -> None: # pylint: disable-msg=too-many-locals
MolToMorganFP(n_bits=14),
)

elements = [physchem_elem, net_charge_elem, morgan_elem, path_elem, maccs_elem]
elements = [
physchem_elem,
net_charge_elem,
morgan_elem,
path_elem,
maccs_elem,
]

for feature_names_prefix in [None, "my_prefix"]:
for use_feature_names_prefix in [False, True]:
# test all subsets are compatible
powerset = itertools.chain.from_iterable(
itertools.combinations(elements, r) for r in range(len(elements) + 1)
Expand All @@ -170,10 +190,18 @@ def test_features_names(self) -> None: # pylint: disable-msg=too-many-locals

for elements_subset in powerset:
conc_elem = MolToConcatenatedVector(
list(elements_subset), feature_names_prefix=feature_names_prefix
list(elements_subset),
use_feature_names_prefix=use_feature_names_prefix,
)
feature_names = conc_elem.feature_names

if use_feature_names_prefix:
# test feature names are unique if prefix is used or only one element is used
self.assertEqual(
len(feature_names),
len(set(feature_names)),
)

# test a feature names and n_features are consistent
self.assertEqual(
len(feature_names),
Expand All @@ -188,23 +216,102 @@ def test_features_names(self) -> None: # pylint: disable-msg=too-many-locals
relevant_names = feature_names[
seen_names : seen_names + elem_n_features
]
prefixes, feat_names = map(
list, zip(*[name.split("__") for name in relevant_names])
)
# test feature names are the same
self.assertListEqual(elem_feature_names, feat_names)

if feature_names_prefix is not None:
# test prefixes are the same user given prefix
self.assertTrue(
all(prefix == feature_names_prefix for prefix in prefixes)
if use_feature_names_prefix:
# feature_names should be prefixed with element name
prefixes, feat_names = map(
list, zip(*[name.split("__") for name in relevant_names])
)
else:
# test feature names are the same
self.assertListEqual(elem_feature_names, feat_names)
# test prefixes are the same as element names
self.assertTrue(all(prefix == elem_name for prefix in prefixes))
else:
# feature_names should be the same as element feature names
self.assertListEqual(elem_feature_names, relevant_names)

seen_names += elem_n_features

def test_logging_feature_names_uniqueness(self) -> None:
    """Test that exactly one warning is logged for non-unique feature names."""
    fingerprint_elements = [
        ("MorganFP", MolToMorganFP(n_bits=17)),
        ("MorganFP_with_feats", MolToMorganFP(n_bits=16, use_features=True)),
    ]

    # Case 1: no prefix, so the test expects duplicate names and one warning.
    with capture_logs() as captured:
        unprefixed_elem = MolToConcatenatedVector(
            fingerprint_elements,
            use_feature_names_prefix=False,
        )
        unprefixed_names = unprefixed_elem.feature_names

    # exactly one log record, with the expected text and level
    self.assertEqual(len(captured), 1)
    warning = captured[0]
    self.assertIn(
        "Feature names in MolToConcatenatedVector are not unique", warning
    )
    self.assertEqual(warning.record["level"].name, "WARNING")

    # the names themselves must contain duplicates
    self.assertNotEqual(len(unprefixed_names), len(set(unprefixed_names)))

    # Case 2: with prefix, so the test expects unique names and no log output.
    with capture_logs() as captured:
        prefixed_elem = MolToConcatenatedVector(
            fingerprint_elements,
            use_feature_names_prefix=True,
        )
        prefixed_names = prefixed_elem.feature_names

    # nothing may have been logged
    self.assertEqual(len(captured), 0)

    # the names must be unique
    self.assertEqual(len(prefixed_names), len(set(prefixed_names)))

def test_getter_setter(self) -> None:
    """Test get_params and set_params, including the feature name refresh."""
    fingerprint_elements = [
        ("MorganFP", MolToMorganFP(n_bits=17)),
        ("MorganFP_with_feats", MolToMorganFP(n_bits=16, use_features=True)),
    ]
    concat_elem = MolToConcatenatedVector(
        fingerprint_elements,
        use_feature_names_prefix=True,
    )

    # the getter reports what was passed to the constructor
    self.assertEqual(len(concat_elem.get_params()["element_list"]), 2)
    self.assertEqual(concat_elem.get_params()["use_feature_names_prefix"], True)

    # with the prefix enabled there must be no duplicate feature names
    names_with_prefix = concat_elem.feature_names
    self.assertEqual(len(names_with_prefix), len(set(names_with_prefix)))

    # switch the prefix off through the setter
    concat_elem.set_params(use_feature_names_prefix=False)
    self.assertEqual(concat_elem.get_params()["use_feature_names_prefix"], False)

    # the feature names must now contain duplicates
    names_without_prefix = concat_elem.feature_names
    self.assertNotEqual(
        len(names_without_prefix), len(set(names_without_prefix))
    )


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit d5705e4

Please sign in to comment.