Skip to content

Commit

Permalink
include ZNTRACK_FIELD_SUFFIX
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Oct 29, 2024
1 parent 3029729 commit b84d54d
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 10 deletions.
6 changes: 6 additions & 0 deletions zntrack/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ class _ZNTRACK_FIELD_DUMP_TYPE:
ZNTRACK_FIELD_DUMP = _ZNTRACK_FIELD_DUMP_TYPE()


class _ZNTRACK_FIELD_SUFFIX_TYPE:
pass

ZNTRACK_FIELD_SUFFIX = _ZNTRACK_FIELD_SUFFIX_TYPE()


class NodeStatusEnum(enum.Enum):
CREATED = 0
RUNNING = 2
Expand Down
15 changes: 9 additions & 6 deletions zntrack/fields/outs_and_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,25 @@
ZNTRACK_FIELD_DUMP,
ZNTRACK_FIELD_LOAD,
ZNTRACK_INDEPENDENT_OUTPUT_TYPE,
ZNTRACK_FIELD_SUFFIX,
ZNTRACK_OPTION,
ZnTrackOptionEnum,
)
from zntrack.node import Node
from zntrack.plugins import base_getter, plugin_getter


def _outs_getter(self: "Node", name: str):
with self.state.fs.open((self.nwd / name).with_suffix(".json")) as f:
def _outs_getter(self: "Node", name: str, suffix: str):
with self.state.fs.open((self.nwd / name).with_suffix(suffix)) as f:
self.__dict__[name] = json.load(f, cls=znjson.ZnDecoder)


def _outs_save_func(self: "Node", name: str):
(self.nwd / name).with_suffix(".json").write_text(znjson.dumps(getattr(self, name)))
def _outs_save_func(self: "Node", name: str, suffix: str):
(self.nwd / name).with_suffix(suffix).write_text(znjson.dumps(getattr(self, name)))


def _metrics_save_func(self: "Node", name: str):
(self.nwd / name).with_suffix(".json").write_text(json.dumps(getattr(self, name)))
def _metrics_save_func(self: "Node", name: str, suffix: str):
(self.nwd / name).with_suffix(suffix).write_text(json.dumps(getattr(self, name)))


def outs(*, cache: bool = True, independent: bool = False, **kwargs) -> znfields.field:
Expand All @@ -39,6 +40,7 @@ def outs(*, cache: bool = True, independent: bool = False, **kwargs) -> znfields
base_getter, func=_outs_getter
)
kwargs["metadata"][ZNTRACK_FIELD_DUMP] = _outs_save_func
kwargs["metadata"][ZNTRACK_FIELD_SUFFIX] = ".json"
return znfields.field(
default=NOT_AVAILABLE, getter=plugin_getter, **kwargs, init=False
)
Expand All @@ -55,6 +57,7 @@ def metrics(
base_getter, func=_outs_getter
)
kwargs["metadata"][ZNTRACK_FIELD_DUMP] = _metrics_save_func
kwargs["metadata"][ZNTRACK_FIELD_SUFFIX] = ".json"
return znfields.field(
default=NOT_AVAILABLE, getter=plugin_getter, **kwargs, init=False
)
12 changes: 9 additions & 3 deletions zntrack/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def finalize(cls, **kwargs) -> None:
return


def base_getter(self: "Node", name: str, func: t.Callable):
def base_getter(self: "Node", name: str, func: t.Callable, suffix: t.Optional[str] = None):
if (
name in self.__dict__
and self.__dict__[name] is not ZNTRACK_LAZY_VALUE
Expand All @@ -115,12 +115,18 @@ def base_getter(self: "Node", name: str, func: t.Callable):

if name in self.__dict__ and self.__dict__[name] is NOT_AVAILABLE:
try:
func(self, name)
if suffix is not None:
func(self, name, suffix)
else:
func(self, name)
except FileNotFoundError:
return NOT_AVAILABLE

try:
func(self, name)
if suffix is not None:
func(self, name, suffix)
else:
func(self, name)
except FileNotFoundError:
return NOT_AVAILABLE

Expand Down
11 changes: 10 additions & 1 deletion zntrack/plugins/dvc_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ZNTRACK_FIELD_LOAD,
ZNTRACK_FILE_PATH,
ZNTRACK_LAZY_VALUE,
ZNTRACK_FIELD_SUFFIX,
ZNTRACK_OPTION,
ZNTRACK_OPTION_PLOTS_CONFIG,
ZnTrackOptionEnum,
Expand All @@ -44,15 +45,23 @@
class DVCPlugin(ZnTrackPlugin):
def getter(self, field: dataclasses.Field) -> t.Any:
getter = field.metadata.get(ZNTRACK_FIELD_LOAD)
suffix = field.metadata.get(ZNTRACK_FIELD_SUFFIX)

if getter is not None:
if suffix is not None:
return getter(self.node, field.name, suffix=suffix)
return getter(self.node, field.name)
return PLUGIN_EMPTY_RETRUN_VALUE

def save(self, field: dataclasses.Field) -> None:
dump_func = field.metadata.get(ZNTRACK_FIELD_DUMP)
suffix = field.metadata.get(ZNTRACK_FIELD_SUFFIX)

if dump_func is not None:
dump_func(self.node, field.name)
if suffix is not None:
dump_func(self.node, field.name, suffix=suffix)
else:
dump_func(self.node, field.name)

def convert_to_params_yaml(self) -> dict | object:
data = {}
Expand Down

0 comments on commit b84d54d

Please sign in to comment.