Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally avoid deserialization when resolving references #512

Merged
merged 2 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/jobflow/core/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def resolve(
store: jobflow.JobStore | None,
cache: dict[str, Any] = None,
on_missing: OnMissing = OnMissing.ERROR,
deserialize: bool = True,
) -> Any:
"""
Resolve the reference.
Expand All @@ -127,6 +128,10 @@ def resolve(
on_missing
What to do if the output reference is missing in the database and cache.
See :obj:`OnMissing` for the available options.
deserialize
If False, the data extracted from the store will not be deserialized.
Note that in this case, if a reference contains a derived property,
it cannot be resolved.

Raises
------
Expand Down Expand Up @@ -170,7 +175,8 @@ def resolve(
data = cache[self.uuid][index]

# decode objects before attribute access
data = MontyDecoder().process_decoded(data)
if deserialize:
data = MontyDecoder().process_decoded(data)

# re-cache data in case other references need it
cache[self.uuid][index] = data
Expand Down Expand Up @@ -304,6 +310,7 @@ def resolve_references(
store: jobflow.JobStore,
cache: dict[str, Any] = None,
on_missing: OnMissing = OnMissing.ERROR,
deserialize: bool = True,
) -> dict[OutputReference, Any]:
"""
Resolve multiple output references.
Expand All @@ -321,6 +328,10 @@ def resolve_references(
on_missing
What to do if the output reference is missing in the database and cache.
See :obj:`OnMissing` for the available options.
deserialize
If False, the data extracted from the store will not be deserialized.
Note that in this case, if a reference contains a derived property,
it cannot be resolved.

Returns
-------
Expand Down Expand Up @@ -348,7 +359,7 @@ def resolve_references(

for ref in ref_group:
resolved_references[ref] = ref.resolve(
store, cache=cache, on_missing=on_missing
store, cache=cache, on_missing=on_missing, deserialize=deserialize
)

return resolved_references
Expand Down Expand Up @@ -397,6 +408,7 @@ def find_and_resolve_references(
store: jobflow.JobStore,
cache: dict[str, Any] = None,
on_missing: OnMissing = OnMissing.ERROR,
deserialize: bool = True,
) -> Any:
"""
Return the input but with all output references replaced with their resolved values.
Expand All @@ -415,6 +427,10 @@ def find_and_resolve_references(
on_missing
What to do if the output reference is missing in the database and cache.
See :obj:`OnMissing` for the available options.
deserialize
If False, the data extracted from the store will not be deserialized.
Note that in this case, if a reference contains a derived property,
it cannot be resolved.

Returns
-------
Expand All @@ -428,12 +444,14 @@ def find_and_resolve_references(
from jobflow.utils.find import find_key_value

if isinstance(arg, dict) and arg.get("@class") == "OutputReference":
# if arg is a deserialized reference, serialize it
# if arg is a serialized reference, deserialize it
arg = OutputReference.from_dict(arg)

if isinstance(arg, OutputReference):
# if the argument is a reference then stop there
return arg.resolve(store, cache=cache, on_missing=on_missing)
return arg.resolve(
store, cache=cache, on_missing=on_missing, deserialize=deserialize
)

if isinstance(arg, (float, int, str, bool)):
# argument is a primitive, we won't find a reference here
Expand All @@ -453,7 +471,7 @@ def find_and_resolve_references(
OutputReference.from_dict(get(encoded_arg, list(loc))) for loc in locations
]
resolved_references = resolve_references(
references, store, cache=cache, on_missing=on_missing
references, store, cache=cache, on_missing=on_missing, deserialize=deserialize
)

# replace the references in the arg dict
Expand Down
41 changes: 40 additions & 1 deletion tests/core/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,30 +362,64 @@ def test_find_and_get_references():


def test_find_and_resolve_references(memory_jobstore):
from monty.json import MSONable

from jobflow.core.reference import (
OnMissing,
OutputReference,
find_and_resolve_references,
)

global WithProp

class WithProp(MSONable):
def __init__(self, x):
self.x = x

@property
def plus(self):
return self.x + 1

ref1 = OutputReference("123")
ref2 = OutputReference("1234", (("i", "a"),))
ref_attr = OutputReference("123456", (("a", "x"),))
ref_prop = OutputReference("123456", (("a", "plus"),))
memory_jobstore.update({"uuid": "123", "index": 1, "output": 101})
memory_jobstore.update({"uuid": "1234", "index": 1, "output": {"a": "xyz", "b": 5}})
memory_jobstore.update({"uuid": "123456", "index": 1, "output": WithProp(1)})

# test no reference
assert find_and_resolve_references(arg=True, store=memory_jobstore) is True
assert find_and_resolve_references("xyz", memory_jobstore) == "xyz"
assert (
find_and_resolve_references("xyz", memory_jobstore, deserialize=False) == "xyz"
)
assert find_and_resolve_references([101], memory_jobstore) == [101]
assert find_and_resolve_references([101], memory_jobstore, deserialize=False) == [
101
]

# test single reference
assert find_and_resolve_references(ref1, memory_jobstore) == 101
assert find_and_resolve_references(ref1, memory_jobstore, deserialize=False) == 101

# test single reference with object
assert find_and_resolve_references(ref_attr, memory_jobstore) == 1
assert find_and_resolve_references(ref_prop, memory_jobstore) == 2
assert (
find_and_resolve_references(ref_attr, memory_jobstore, deserialize=False) == 1
)
with pytest.raises(KeyError, match="plus"):
find_and_resolve_references(ref_prop, memory_jobstore, deserialize=False)

# test list and tuple of references
assert find_and_resolve_references([ref1], memory_jobstore) == [101]
assert find_and_resolve_references([ref1, ref2], memory_jobstore) == [101, "xyz"]
assert find_and_resolve_references(
[ref1, ref2], memory_jobstore, deserialize=False
) == [101, "xyz"]

# test dictionary dictionary values
# test dictionary values
output = find_and_resolve_references({"a": ref1}, memory_jobstore)
assert output == {"a": 101}
output = find_and_resolve_references({"a": ref1, "b": ref2}, memory_jobstore)
Expand Down Expand Up @@ -433,6 +467,11 @@ def test_find_and_resolve_references(memory_jobstore):
[ref1, ref3], memory_jobstore, on_missing=OnMissing.ERROR
)

with pytest.raises(ValueError, match="Could not resolve reference"):
find_and_resolve_references(
[ref1, ref3], memory_jobstore, on_missing=OnMissing.ERROR, deserialize=False
)


def test_circular_resolve(memory_jobstore):
from jobflow.core.reference import OutputReference
Expand Down
Loading