Skip to content

Commit

Permalink
Merge pull request #100 from mgxd/fix/cache-propagation
Browse files Browse the repository at this point in the history
Fix/cache propagation
  • Loading branch information
satra authored Jul 26, 2019
2 parents b466623 + 7d51095 commit f59083f
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pydra/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class TaskBase:
# TODO: write state should be removed
def __init__(
self,
name,
name: str,
inputs: ty.Union[ty.Text, File, ty.Dict, None] = None,
audit_flags: AuditFlag = AuditFlag.NONE,
messengers=None,
Expand Down Expand Up @@ -122,6 +122,7 @@ def __init__(
)
self.cache_dir = cache_dir
self.cache_locations = cache_locations
self.allow_cache_override = True
self._checksum = None

# dictionary of results from tasks
Expand Down
2 changes: 2 additions & 0 deletions pydra/engine/submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __call__(self, runnable, cache_locations=None):
if is_workflow(runnable):
for nd in runnable.graph.nodes:
runnable.create_connections(nd)
if nd.allow_cache_override:
nd.cache_dir = runnable.cache_dir
runnable.inputs._graph_checksums = [
nd.checksum for nd in runnable.graph_sorted
]
Expand Down
41 changes: 41 additions & 0 deletions pydra/engine/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,3 +1754,44 @@ def test_wf_ndstate_cachelocations_recompute(plugin, tmpdir):
# checking if the second wf didn't run again
# checking all directories
assert wf2.output_dir.exists()


@pytest.fixture
def create_tasks():
wf = Workflow(name="wf", input_spec=["x"])
wf.inputs.x = 1
wf.add(add2(name="t1", x=wf.lzin.x))
wf.add(multiply(name="t2", x=wf.t1.lzout.out, y=2))
wf.set_output([("out", wf.t2.lzout.out)])
t1 = wf.name2obj["t1"]
t2 = wf.name2obj["t2"]
return wf, t1, t2


def test_cache_propagation1(tmpdir, create_tasks):
"""No cache set, all independent"""
wf, t1, t2 = create_tasks
wf(plugin="cf")
assert wf.cache_dir == t1.cache_dir == t2.cache_dir
wf.cache_dir = (tmpdir / "shared").strpath
wf(plugin="cf")
assert wf.cache_dir == t1.cache_dir == t2.cache_dir


def test_cache_propagation2(tmpdir, create_tasks):
"""Task explicitly states no inheriting"""
wf, t1, t2 = create_tasks
wf.cache_dir = (tmpdir / "shared").strpath
t2.allow_cache_override = False
wf(plugin="cf")
assert wf.cache_dir == t1.cache_dir != t2.cache_dir


def test_cache_propagation3(tmpdir, create_tasks):
"""Shared cache_dir with state"""
wf, t1, t2 = create_tasks
wf.inputs.x = [1, 2]
wf.split("x")
wf.cache_dir = (tmpdir / "shared").strpath
wf(plugin="cf")
assert wf.cache_dir == t1.cache_dir == t2.cache_dir

0 comments on commit f59083f

Please sign in to comment.