Skip to content

Commit

Permalink
move deps
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Oct 29, 2024
1 parent 2c2c059 commit 7bb5efc
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 56 deletions.
1 change: 0 additions & 1 deletion tests/files/test_custom_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@


class NodeFromCustomModule(zntrack.Node):

_module_ = "zntrack.mymodule"

def run(self):
Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_post_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class NodeWithPostInit(zntrack.Node):

def __post_init__(self):
self.value = 42

Expand Down
2 changes: 1 addition & 1 deletion zntrack-examples
Submodule zntrack-examples updated 1 files
+4 −9 main.py
22 changes: 9 additions & 13 deletions zntrack/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
ZNTRACK_OPTION_PLOTS_CONFIG,
ZnTrackOptionEnum,
)
from zntrack.fields.x_path import metrics_path, outs_path, params_path, plots_path, deps_path
from zntrack.fields.deps import deps
from zntrack.fields.x_path import (
deps_path,
metrics_path,
outs_path,
params_path,
plots_path,
)
from zntrack.node import Node
from zntrack.plugins import plugin_getter

Expand All @@ -21,7 +28,7 @@
# TODO: zntrack.outs() and zntrack.outs(cache=False) needs different files!


__all__ = ["outs_path", "params_path", "plots_path", "metrics_path", "deps_path"]
__all__ = ["outs_path", "params_path", "plots_path", "metrics_path", "deps_path", "deps"]


def _plots_autosave_setter(self: Node, name: str, value: pd.DataFrame):
Expand All @@ -39,15 +46,6 @@ def params(default=dataclasses.MISSING, **kwargs) -> znfields.field:
)


def deps(default=dataclasses.MISSING, **kwargs) -> znfields.field:
return znfields.field(
default=default,
metadata={ZNTRACK_OPTION: ZnTrackOptionEnum.DEPS},
getter=plugin_getter,
**kwargs,
)


def outs(*, cache: bool = True, independent: bool = False, **kwargs) -> znfields.field:
kwargs["metadata"] = kwargs.get("metadata", {})
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.OUTS
Expand Down Expand Up @@ -136,5 +134,3 @@ def metrics(
return znfields.field(
default=NOT_AVAILABLE, getter=plugin_getter, **kwargs, init=False
)


63 changes: 63 additions & 0 deletions zntrack/fields/deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import dataclasses
import json

import znfields
import znflow
import znflow.handler
import znflow.utils
import znjson

from zntrack import converter
from zntrack.config import (
ZNTRACK_FIELD_GETTER,
ZNTRACK_FILE_PATH,
ZNTRACK_OPTION,
ZnTrackOptionEnum,
)
from zntrack.node import Node
from zntrack.plugins import base_getter, plugin_getter
import functools

# if t.TYPE_CHECKING:


def _deps_getter(self: "Node", name: str):
with self.state.fs.open(ZNTRACK_FILE_PATH) as f:
content = json.load(f)[self.name][name]
# TODO: Ensure deps are loaded from the correct revision
content = znjson.loads(
json.dumps(content),
cls=znjson.ZnDecoder.from_converters(
[
converter.NodeConverter,
converter.ConnectionConverter,
converter.CombinedConnectionsConverter,
converter.DVCImportPathConverter,
converter.DataclassConverter,
],
add_default=True,
),
)
if isinstance(content, converter.DataclassContainer):
content = content.get_with_params(self.name, name)
if isinstance(content, list):
new_content = []
idx = 0
for val in content:
if isinstance(val, converter.DataclassContainer):
new_content.append(val.get_with_params(self.name, name, idx))
idx += 1 # index only runs over dataclasses
else:
new_content.append(val)
content = new_content

content = znflow.handler.UpdateConnectors()(content)

self.__dict__[name] = content


def deps(default=dataclasses.MISSING, **kwargs) -> znfields.field:
kwargs["metadata"] = kwargs.get("metadata", {})
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.DEPS
kwargs["metadata"][ZNTRACK_FIELD_GETTER] = functools.partial(base_getter, func=_deps_getter)
return znfields.field(default=default, getter=plugin_getter, **kwargs)
3 changes: 1 addition & 2 deletions zntrack/fields/x_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,11 @@ def metrics_path(
return znfields.field(default=default, getter=plugin_getter, **kwargs)



def deps_path(
default=dataclasses.MISSING, *, cache: bool = True, **kwargs
) -> znfields.field:
kwargs["metadata"] = kwargs.get("metadata", {})
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.DEPS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_FIELD_GETTER] = _paths_getter
return znfields.field(default=default, getter=plugin_getter, **kwargs)
return znfields.field(default=default, getter=plugin_getter, **kwargs)
38 changes: 0 additions & 38 deletions zntrack/plugins/dvc_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,41 +79,6 @@ def _paths_getter(self: "Node", name: str):
return NOT_AVAILABLE


def _deps_getter(self: "Node", name: str):
with self.state.fs.open(ZNTRACK_FILE_PATH) as f:
content = json.load(f)[self.name][name]
# TODO: Ensure deps are loaded from the correct revision
content = znjson.loads(
json.dumps(content),
cls=znjson.ZnDecoder.from_converters(
[
converter.NodeConverter,
converter.ConnectionConverter,
converter.CombinedConnectionsConverter,
converter.DVCImportPathConverter,
converter.DataclassConverter,
],
add_default=True,
),
)
if isinstance(content, converter.DataclassContainer):
content = content.get_with_params(self.name, name)
if isinstance(content, list):
new_content = []
idx = 0
for val in content:
if isinstance(val, converter.DataclassContainer):
new_content.append(val.get_with_params(self.name, name, idx))
idx += 1 # index only runs over dataclasses
else:
new_content.append(val)
content = new_content

content = znflow.handler.UpdateConnectors()(content)

self.__dict__[name] = content


def _params_getter(self: "Node", name: str):
with self.state.fs.open(PARAMS_FILE_PATH) as f:
self.__dict__[name] = yaml.safe_load(f)[self.name][name]
Expand All @@ -137,9 +102,6 @@ def getter(self, field: dataclasses.Field) -> t.Any:

if getter is not None:
return getter(self.node, field.name)

if option == ZnTrackOptionEnum.DEPS:
return base_getter(self.node, field.name, _deps_getter)
elif option == ZnTrackOptionEnum.PARAMS:
return base_getter(self.node, field.name, _params_getter)
elif option == ZnTrackOptionEnum.PLOTS:
Expand Down

0 comments on commit 7bb5efc

Please sign in to comment.