Skip to content

Commit

Permalink
use ZNTRACK_FIELD_SUFFIX
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Oct 29, 2024
1 parent b84d54d commit 032b5f4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
10 changes: 6 additions & 4 deletions zntrack/fields/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,26 @@
ZNTRACK_OPTION,
ZNTRACK_OPTION_PLOTS_CONFIG,
ZnTrackOptionEnum,
ZNTRACK_FIELD_SUFFIX,
)
from zntrack.node import Node
from zntrack.plugins import base_getter, plugin_getter


def _plots_save_func(self: "Node", name: str):
def _plots_save_func(self: "Node", name: str, suffix: str):
content = getattr(self, name)
if not isinstance(content, pd.DataFrame):
raise TypeError(f"Expected a pandas DataFrame, got {type(content)}")
content.to_csv((self.nwd / name).with_suffix(".csv"))
content.to_csv((self.nwd / name).with_suffix(suffix))


def _plots_autosave_setter(self: Node, name: str, value: pd.DataFrame):
value.to_csv((self.nwd / name).with_suffix(".csv"))
self.__dict__[name] = value


def _plots_getter(self: "Node", name: str):
with self.state.fs.open((self.nwd / name).with_suffix(".csv")) as f:
def _plots_getter(self: "Node", name: str, suffix: str):
with self.state.fs.open((self.nwd / name).with_suffix(suffix)) as f:
self.__dict__[name] = pd.read_csv(f, index_col=0)


Expand Down Expand Up @@ -103,6 +104,7 @@ def plots(
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = functools.partial(
base_getter, func=_plots_getter
)
kwargs["metadata"][ZNTRACK_FIELD_SUFFIX] = ".csv"
return znfields.field(
default=NOT_AVAILABLE, getter=plugin_getter, **kwargs, init=False
)
11 changes: 7 additions & 4 deletions zntrack/plugins/dvc_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,18 +187,20 @@ def convert_to_dvc_yaml(self) -> dict | object:
content = [{c: {"cache": False}} for c in content]
stages.setdefault(ZnTrackOptionEnum.METRICS.value, []).extend(content)
elif field.metadata.get(ZNTRACK_OPTION) == ZnTrackOptionEnum.OUTS:
content = [(self.node.nwd / field.name).with_suffix(".json").as_posix()]
suffix = field.metadata[ZNTRACK_FIELD_SUFFIX]
content = [(self.node.nwd / field.name).with_suffix(suffix).as_posix()]
if field.metadata.get(ZNTRACK_CACHE) is False:
content = [{c: {"cache": False}} for c in content]
stages.setdefault(ZnTrackOptionEnum.OUTS.value, []).extend(content)
elif field.metadata.get(ZNTRACK_OPTION) == ZnTrackOptionEnum.PLOTS:
content = [(self.node.nwd / field.name).with_suffix(".csv").as_posix()]
suffix = field.metadata[ZNTRACK_FIELD_SUFFIX]
content = [(self.node.nwd / field.name).with_suffix(suffix).as_posix()]
if field.metadata.get(ZNTRACK_CACHE) is False:
content = [{c: {"cache": False}} for c in content]
stages.setdefault(ZnTrackOptionEnum.OUTS.value, []).extend(content)
if ZNTRACK_OPTION_PLOTS_CONFIG in field.metadata:
file_path = (
(self.node.nwd / field.name).with_suffix(".csv").as_posix()
(self.node.nwd / field.name).with_suffix(suffix).as_posix()
)
plots_config = field.metadata[ZNTRACK_OPTION_PLOTS_CONFIG].copy()
if "x" not in plots_config or "y" not in plots_config:
Expand All @@ -217,7 +219,8 @@ def convert_to_dvc_yaml(self) -> dict | object:
plots_config["y"] = {file_path: plots_config["y"]}
plots.append({f"{self.node.name}_{field.name}": plots_config})
elif field.metadata.get(ZNTRACK_OPTION) == ZnTrackOptionEnum.METRICS:
content = [(self.node.nwd / field.name).with_suffix(".json").as_posix()]
suffix = field.metadata[ZNTRACK_FIELD_SUFFIX]
content = [(self.node.nwd / field.name).with_suffix(suffix).as_posix()]
if field.metadata.get(ZNTRACK_CACHE) is False:
content = [{c: {"cache": False}} for c in content]
stages.setdefault(ZnTrackOptionEnum.METRICS.value, []).extend(content)
Expand Down

0 comments on commit 032b5f4

Please sign in to comment.