Skip to content

Commit

Permalink
Merge pull request #67 from chrishavlin/single_field_load_grid
Browse files Browse the repository at this point in the history
allow single field for load_grid
  • Loading branch information
chrishavlin authored Nov 3, 2023
2 parents 7e1dfea + c3d5318 commit 740e086
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

[![PyPI version](https://badge.fury.io/py/yt_xarray.svg)](https://badge.fury.io/py/yt_xarray)
[![Python Version](https://img.shields.io/pypi/pyversions/yt_xarray.svg?color=green)](https://python.org)
[![Tests](https://github.com/data-exp-lab/yt_xarray/actions/workflows/run-pytest-tests.yml/badge.svg)](https://github.com/data-exp-lab/yt_xarray/actions/workflows/run-pytest-tests.yml)
[![codecov](https://codecov.io/gh/data-exp-lab/yt_xarray/branch/main/graph/badge.svg)](https://codecov.io/gh/data-exp-lab/yt_xarray)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/data-exp-lab/yt_xarray/main.svg)](https://results.pre-commit.ci/latest/github/data-exp-lab/yt_xarray/main)

Expand Down
11 changes: 8 additions & 3 deletions yt_xarray/accessor/accessor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import List, Optional
from typing import List, Optional, Union

import numpy as np
import xarray as xr
Expand All @@ -22,7 +22,7 @@ def __init__(self, xarray_obj):

def load_grid(
self,
fields: Optional[List[str]] = None,
fields: Optional[Union[str, List[str]]] = None,
geometry: str = None,
use_callable: bool = True,
sel_dict: Optional[dict] = None,
Expand All @@ -35,7 +35,7 @@ def load_grid(
Parameters
----------
fields : list[str]
fields : str, list[str]
list of fields to include. If None, will try to use all fields
geometry : str
Expand Down Expand Up @@ -66,6 +66,11 @@ def load_grid(
# might as well try!
fields = list(self._obj.data_vars)

if isinstance(fields, str):
fields = [
fields,
]

sel_info = _xr_to_yt.Selection(
self._obj,
fields=fields,
Expand Down
6 changes: 6 additions & 0 deletions yt_xarray/tests/test_accesor.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,9 @@ def test_stretched_grid(use_callable):
cell_centers = (dims[:-1] + dims[1:]) / 2
dimvals = np.unique(ad[("index", dim)].d)
assert np.all(dimvals == cell_centers)


def test_load_single_field(ds_xr):
flds = "a_new_field_0"
ds_yt = ds_xr.yt.load_grid(flds)
_ = ds_yt.all_data()[("stream", flds)]

0 comments on commit 740e086

Please sign in to comment.