Skip to content
This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

Allow passing os.PathLike as an alternative to strings to represent node paths #282

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from xarray.core.merge import CoercibleValue
from xarray.core.types import ErrorOptions

from datatree.treenode import T_PathLike

# """
# DEVELOPERS' NOTE
# ----------------
Expand All @@ -76,9 +78,6 @@
# """


T_Path = Union[str, NodePath]


def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset:
if isinstance(data, DataArray):
ds = data.to_dataset()
Expand Down Expand Up @@ -848,7 +847,7 @@ def get(
else:
return default

def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
def __getitem__(self: DataTree, key: T_PathLike) -> DataTree | DataArray:
"""
Access child nodes, variables, or coordinates stored anywhere in this tree.

Expand Down Expand Up @@ -903,7 +902,7 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None:

def __setitem__(
self,
key: str,
key: T_PathLike,
value: Any,
) -> None:
"""
Expand Down Expand Up @@ -1034,7 +1033,7 @@ def drop_nodes(
@classmethod
def from_dict(
cls,
d: MutableMapping[str, Dataset | DataArray | DataTree | None],
d: MutableMapping[T_PathLike, Dataset | DataArray | DataTree | None],
name: Optional[str] = None,
) -> DataTree:
"""
Expand Down Expand Up @@ -1442,7 +1441,7 @@ def merge(self, datatree: DataTree) -> DataTree:
"""Merge all the leaves of a second DataTree into this one."""
raise NotImplementedError

def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree:
def merge_child_nodes(self, *paths, new_path: T_PathLike) -> DataTree:
"""Merge a set of child nodes into a single new node."""
raise NotImplementedError

Expand Down
4 changes: 2 additions & 2 deletions datatree/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Tuple
from typing import TYPE_CHECKING, Callable, Tuple, Union

from xarray import DataArray, Dataset

Expand Down Expand Up @@ -228,7 +228,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
original_root_path = first_tree.path
result_trees = []
for i in range(num_return_values):
out_tree_contents = {}
out_tree_contents: dict[str, Union[None, Dataset]] = {}
for n in first_tree.subtree:
p = n.path
if p in out_data_objects.keys():
Expand Down
14 changes: 8 additions & 6 deletions datatree/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
from xarray.core.utils import Frozen, is_dict_like

if TYPE_CHECKING:
from os import PathLike

from xarray.core.types import T_DataArray

T_PathLike = Union[str, PathLike]


class InvalidTreeError(Exception):
"""Raised when user attempts to create an invalid tree in some way."""
Expand Down Expand Up @@ -445,14 +449,13 @@ def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]:

# TODO `._walk` method to be called by both `_get_item` and `_set_item`

def _get_item(self: Tree, path: str | NodePath) -> Union[Tree, T_DataArray]:
def _get_item(self: Tree, path: T_PathLike) -> Union[Tree, T_DataArray]:
"""
Returns the object lying at the given path.

Raises a KeyError if there is no object at the given path.
"""
if isinstance(path, str):
path = NodePath(path)
path = NodePath(path)

if path.root:
current_node = self.root
Expand Down Expand Up @@ -487,7 +490,7 @@ def _set(self: Tree, key: str, val: Tree) -> None:

def _set_item(
self: Tree,
path: str | NodePath,
path: T_PathLike,
item: Union[Tree, T_DataArray],
new_nodes_along_path: bool = False,
allow_overwrite: bool = True,
Expand All @@ -513,8 +516,7 @@ def _set_item(
If node cannot be reached, and new_nodes_along_path=False.
Or if a node already exists at the specified path, and allow_overwrite=False.
"""
if isinstance(path, str):
path = NodePath(path)
path = NodePath(path)

if not path.name:
raise ValueError("Can't set an item under a path which has no name")
Expand Down
3 changes: 3 additions & 0 deletions docs/source/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ v0.0.14 (unreleased)
New Features
~~~~~~~~~~~~

- Allow passing :py:class:`os.PathLike` objects as paths to nodes in addition to strings. (:pull:`282`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Breaking changes
~~~~~~~~~~~~~~~~

Expand Down
Loading