Skip to content

Commit

Permalink
Merge pull request #512 from gpetretto/devel
Browse files (browse the repository at this point in the history)
Optionally avoid deserialization when resolving references
  • Loading branch information
utf authored Jan 7, 2024
2 parents c60b8d0 + 2917456 commit af862b3
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 6 deletions.
28 changes: 23 additions & 5 deletions src/jobflow/core/reference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -106,6 +106,7 @@ def resolve(
store: jobflow.JobStore | None,
cache: dict[str, Any] = None,
on_missing: OnMissing = OnMissing.ERROR,
deserialize: bool = True,
) -> Any:
"""
Resolve the reference.
Expand All @@ -127,6 +128,10 @@ def resolve(
on_missing
What to do if the output reference is missing in the database and cache.
See :obj:`OnMissing` for the available options.
deserialize
If False, the data extracted from the store will not be deserialized.
Note that in this case, if a reference contains a derived property,
it cannot be resolved.
Raises
------
Expand Down Expand Up @@ -170,7 +175,8 @@ def resolve(
data = cache[self.uuid][index]

# decode objects before attribute access
data = MontyDecoder().process_decoded(data)
if deserialize:
data = MontyDecoder().process_decoded(data)

# re-cache data in case other references need it
cache[self.uuid][index] = data
Expand Down Expand Up @@ -304,6 +310,7 @@ def resolve_references(
store: jobflow.JobStore,
cache: dict[str, Any] = None,
on_missing: OnMissing = OnMissing.ERROR,
deserialize: bool = True,
) -> dict[OutputReference, Any]:
"""
Resolve multiple output references.
Expand All @@ -321,6 +328,10 @@ def resolve_references(
on_missing
What to do if the output reference is missing in the database and cache.
See :obj:`OnMissing` for the available options.
deserialize
If False, the data extracted from the store will not be deserialized.
Note that in this case, if a reference contains a derived property,
it cannot be resolved.
Returns
-------
Expand Down Expand Up @@ -348,7 +359,7 @@ def resolve_references(

for ref in ref_group:
resolved_references[ref] = ref.resolve(
store, cache=cache, on_missing=on_missing
store, cache=cache, on_missing=on_missing, deserialize=deserialize
)

return resolved_references
Expand Down Expand Up @@ -397,6 +408,7 @@ def find_and_resolve_references(
store: jobflow.JobStore,
cache: dict[str, Any] = None,
on_missing: OnMissing = OnMissing.ERROR,
deserialize: bool = True,
) -> Any:
"""
Return the input but with all output references replaced with their resolved values.
Expand All @@ -415,6 +427,10 @@ def find_and_resolve_references(
on_missing
What to do if the output reference is missing in the database and cache.
See :obj:`OnMissing` for the available options.
deserialize
If False, the data extracted from the store will not be deserialized.
Note that in this case, if a reference contains a derived property,
it cannot be resolved.
Returns
-------
Expand All @@ -428,12 +444,14 @@ def find_and_resolve_references(
from jobflow.utils.find import find_key_value

if isinstance(arg, dict) and arg.get("@class") == "OutputReference":
# if arg is a deserialized reference, serialize it
# if arg is a serialized reference, deserialize it
arg = OutputReference.from_dict(arg)

if isinstance(arg, OutputReference):
# if the argument is a reference then stop there
return arg.resolve(store, cache=cache, on_missing=on_missing)
return arg.resolve(
store, cache=cache, on_missing=on_missing, deserialize=deserialize
)

if isinstance(arg, (float, int, str, bool)):
# argument is a primitive, we won't find a reference here
Expand All @@ -453,7 +471,7 @@ def find_and_resolve_references(
OutputReference.from_dict(get(encoded_arg, list(loc))) for loc in locations
]
resolved_references = resolve_references(
references, store, cache=cache, on_missing=on_missing
references, store, cache=cache, on_missing=on_missing, deserialize=deserialize
)

# replace the references in the arg dict
Expand Down
41 changes: 40 additions & 1 deletion tests/core/test_reference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -362,30 +362,64 @@ def test_find_and_get_references():


def test_find_and_resolve_references(memory_jobstore):
from monty.json import MSONable

from jobflow.core.reference import (
OnMissing,
OutputReference,
find_and_resolve_references,
)

global WithProp

class WithProp(MSONable):
def __init__(self, x):
self.x = x

@property
def plus(self):
return self.x + 1

ref1 = OutputReference("123")
ref2 = OutputReference("1234", (("i", "a"),))
ref_attr = OutputReference("123456", (("a", "x"),))
ref_prop = OutputReference("123456", (("a", "plus"),))
memory_jobstore.update({"uuid": "123", "index": 1, "output": 101})
memory_jobstore.update({"uuid": "1234", "index": 1, "output": {"a": "xyz", "b": 5}})
memory_jobstore.update({"uuid": "123456", "index": 1, "output": WithProp(1)})

# test no reference
assert find_and_resolve_references(arg=True, store=memory_jobstore) is True
assert find_and_resolve_references("xyz", memory_jobstore) == "xyz"
assert (
find_and_resolve_references("xyz", memory_jobstore, deserialize=False) == "xyz"
)
assert find_and_resolve_references([101], memory_jobstore) == [101]
assert find_and_resolve_references([101], memory_jobstore, deserialize=False) == [
101
]

# test single reference
assert find_and_resolve_references(ref1, memory_jobstore) == 101
assert find_and_resolve_references(ref1, memory_jobstore, deserialize=False) == 101

# test single reference with object
assert find_and_resolve_references(ref_attr, memory_jobstore) == 1
assert find_and_resolve_references(ref_prop, memory_jobstore) == 2
assert (
find_and_resolve_references(ref_attr, memory_jobstore, deserialize=False) == 1
)
with pytest.raises(KeyError, match="plus"):
find_and_resolve_references(ref_prop, memory_jobstore, deserialize=False)

# test list and tuple of references
assert find_and_resolve_references([ref1], memory_jobstore) == [101]
assert find_and_resolve_references([ref1, ref2], memory_jobstore) == [101, "xyz"]
assert find_and_resolve_references(
[ref1, ref2], memory_jobstore, deserialize=False
) == [101, "xyz"]

# test dictionary dictionary values
# test dictionary values
output = find_and_resolve_references({"a": ref1}, memory_jobstore)
assert output == {"a": 101}
output = find_and_resolve_references({"a": ref1, "b": ref2}, memory_jobstore)
Expand Down Expand Up @@ -433,6 +467,11 @@ def test_find_and_resolve_references(memory_jobstore):
[ref1, ref3], memory_jobstore, on_missing=OnMissing.ERROR
)

with pytest.raises(ValueError, match="Could not resolve reference"):
find_and_resolve_references(
[ref1, ref3], memory_jobstore, on_missing=OnMissing.ERROR, deserialize=False
)


def test_circular_resolve(memory_jobstore):
from jobflow.core.reference import OutputReference
Expand Down

0 comments on commit af862b3

Please sign in to comment.