Skip to content

Commit

Permalink
fix(optimize): support kwargs in 'register_strategy' decorator
Browse files Browse the repository at this point in the history
Some optimization strategies can take class level kwargs (such as the
'split_args' kwarg for all composite strategies).

Ensure these get forwarded when using the 'register_strategy' decorator.
  • Loading branch information
ahal committed Oct 2, 2024
1 parent 18ce170 commit ad118b6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/taskgraph/optimize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
registry = {}


def register_strategy(name, args=()):
def register_strategy(name, args=(), kwargs=None):
kwargs = kwargs or {}

def wrap(cls):
if name not in registry:
registry[name] = cls(*args)
registry[name] = cls(*args, **kwargs)
if not hasattr(registry[name], "description"):
registry[name].description = name
return cls
Expand Down
15 changes: 14 additions & 1 deletion test/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@

from taskgraph.graph import Graph
from taskgraph.optimize import base as optimize_mod
from taskgraph.optimize.base import All, Any, Not, OptimizationStrategy
from taskgraph.optimize.base import (
All,
Any,
Not,
OptimizationStrategy,
register_strategy,
)
from taskgraph.task import Task
from taskgraph.taskgraph import TaskGraph

Expand Down Expand Up @@ -467,3 +473,10 @@ def test_get_subgraph_removed_dep():
graph = make_triangle()
with pytest.raises(Exception):
optimize_mod.get_subgraph(graph, {"t2"}, set(), {})


def test_register_strategy(mocker):
m = mocker.Mock()
func = register_strategy("foo", args=("one", "two"), kwargs={"n": 1})
func(m)
m.assert_called_with("one", "two", n=1)

0 comments on commit ad118b6

Please sign in to comment.