diff --git a/src/taskgraph/optimize/base.py b/src/taskgraph/optimize/base.py index c40b9a87..c6e84a5b 100644 --- a/src/taskgraph/optimize/base.py +++ b/src/taskgraph/optimize/base.py @@ -28,10 +28,12 @@ registry = {} -def register_strategy(name, args=()): +def register_strategy(name, args=(), kwargs=None): + kwargs = kwargs or {} + def wrap(cls): if name not in registry: - registry[name] = cls(*args) + registry[name] = cls(*args, **kwargs) if not hasattr(registry[name], "description"): registry[name].description = name return cls diff --git a/test/test_optimize.py b/test/test_optimize.py index 86beb13c..0769c544 100644 --- a/test/test_optimize.py +++ b/test/test_optimize.py @@ -9,7 +9,13 @@ from taskgraph.graph import Graph from taskgraph.optimize import base as optimize_mod -from taskgraph.optimize.base import All, Any, Not, OptimizationStrategy +from taskgraph.optimize.base import ( + All, + Any, + Not, + OptimizationStrategy, + register_strategy, +) from taskgraph.task import Task from taskgraph.taskgraph import TaskGraph @@ -467,3 +473,10 @@ def test_get_subgraph_removed_dep(): graph = make_triangle() with pytest.raises(Exception): optimize_mod.get_subgraph(graph, {"t2"}, set(), {}) + + +def test_register_strategy(mocker): + m = mocker.Mock() + func = register_strategy("foo", args=("one", "two"), kwargs={"n": 1}) + func(m) + m.assert_called_with("one", "two", n=1)