Skip to content

Commit

Permalink
type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann committed Oct 28, 2024
1 parent 43b39e9 commit 4e18ecd
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tests/test_utils/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_are_equal(self) -> None:
get_standardization_pipeline,
get_morgan_physchem_rf_pipeline,
]
for pipeline_method in pipline_method_list:
for pipeline_method in pipline_method_list: # type: Callable[[int], Pipeline]
pipeline_a = pipeline_method()
pipeline_b = pipeline_method()
self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b))
Expand Down
13 changes: 9 additions & 4 deletions tests/utils/default_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@
from molpipeline.post_prediction import PostPredictionWrapper


def get_morgan_physchem_rf_pipeline() -> Pipeline:
def get_morgan_physchem_rf_pipeline(n_jobs: int = 1) -> Pipeline:
"""Get a pipeline combining Morgan fingerprints and physicochemical properties with a RandomForestClassifier.
Parameters
----------
n_jobs: int, default=-1
Number of parallel jobs to use.
Returns
-------
Pipeline
Expand All @@ -47,20 +52,20 @@ def get_morgan_physchem_rf_pipeline() -> Pipeline:
),
),
("error_filter", error_filter),
("rf", RandomForestClassifier()),
("rf", RandomForestClassifier(n_jobs=n_jobs)),
(
"filter_reinserter",
PostPredictionWrapper(
FilterReinserter.from_error_filter(error_filter, None)
),
),
],
n_jobs=1,
n_jobs=n_jobs,
)
return pipeline


def get_standardization_pipeline(n_jobs: int = -1) -> Pipeline:
def get_standardization_pipeline(n_jobs: int = 1) -> Pipeline:
"""Get the standardization pipeline.
Parameters
Expand Down

0 comments on commit 4e18ecd

Please sign in to comment.