Skip to content

Commit

Permalink
ENH: xarray grid output
Browse files Browse the repository at this point in the history
  • Loading branch information
syedhamidali committed Oct 20, 2023
1 parent 3acf03c commit 8ad98a9
Showing 1 changed file with 91 additions and 80 deletions.
171 changes: 91 additions & 80 deletions pyart/core/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,106 +315,117 @@ def to_xarray(self):
in a one dimensional array.
"""

lon, lat = self.get_point_longitude_latitude()
z = self.z["data"]
y = self.y["data"]
x = self.x["data"]

time = np.array([num2date(self.time["data"][0], self.time["units"])])
time = np.array(
[num2date(self.time["data"][0], units=self.time["units"])],
dtype="datetime64[ns]",
)

ds = xarray.Dataset()
for field in list(self.fields.keys()):
field_data = self.fields[field]["data"]
for field, field_info in self.fields.items():
field_data = field_info["data"]
data = xarray.DataArray(
np.ma.expand_dims(field_data, 0),
dims=("time", "z", "y", "x"),
coords={
"time": (["time"], time),
"z": (["z"], z),
"time": time,
"z": z,
"lat": (["y", "x"], lat),
"lon": (["y", "x"], lon),
"y": (["y"], y),
"x": (["x"], x),
"y": y,
"x": x,
},
)
for meta in list(self.fields[field].keys()):

for meta, value in field_info.items():
if meta != "data":
data.attrs.update({meta: self.fields[field][meta]})
data.attrs.update({meta: value})

ds[field] = data

ds.lon.attrs = [
("long_name", "longitude of grid cell center"),
("units", "degree_E"),
("standard_name", "Longitude"),
]
ds.lat.attrs = [
("long_name", "latitude of grid cell center"),
("units", "degree_N"),
("standard_name", "Latitude"),
]

ds.z.attrs = get_metadata("z")
ds.y.attrs = get_metadata("y")
ds.x.attrs = get_metadata("x")

ds.z.encoding["_FillValue"] = None
ds.lat.encoding["_FillValue"] = None
ds.lon.encoding["_FillValue"] = None

# Delayed import
from ..io.grid_io import _make_coordinatesystem_dict

ds["ProjectionCoordinateSystem"] = xarray.DataArray(
data=np.array(1, dtype="int32"),
dims=None,
attrs=_make_coordinatesystem_dict(self),
)
ds.lon.attrs = [
("long_name", "longitude of grid cell center"),
("units", "degree_E"),
("standard_name", "Longitude"),
]
ds.lat.attrs = [
("long_name", "latitude of grid cell center"),
("units", "degree_N"),
("standard_name", "Latitude"),
]

ds.z.attrs = get_metadata("z")
ds.y.attrs = get_metadata("y")
ds.x.attrs = get_metadata("x")

for attr in [ds.z, ds.lat, ds.lon]:
attr.encoding["_FillValue"] = None

# Delayed import
from ..io.grid_io import _make_coordinatesystem_dict

ds["ProjectionCoordinateSystem"] = xarray.DataArray(
data=np.array(1, dtype="int32"),
attrs=_make_coordinatesystem_dict(self),
)

if self.origin_latitude is not None:
ds["origin_latitude"] = xarray.DataArray(
np.ma.expand_dims(self.origin_latitude["data"][0], 0),
dims=("time"),
attrs=get_metadata("origin_latitude"),
)

if self.origin_longitude is not None:
ds["origin_longitude"] = xarray.DataArray(
np.ma.expand_dims(self.origin_longitude["data"][0], 0),
dims=("time"),
attrs=get_metadata("origin_longitude"),
)

if self.origin_altitude is not None:
ds["origin_altitude"] = xarray.DataArray(
np.ma.expand_dims(self.origin_altitude["data"][0], 0),
dims=("time"),
attrs=get_metadata("origin_altitude"),
)

if self.radar_altitude is not None:
ds["radar_altitude"] = xarray.DataArray(
np.ma.expand_dims(self.radar_altitude["data"][0], 0),
dims=("nradar"),
attrs=get_metadata("radar_altitude"),
)

if self.radar_latitude is not None:
ds["radar_latitude"] = xarray.DataArray(
np.ma.expand_dims(self.radar_latitude["data"][0], 0),
dims=("nradar"),
attrs=get_metadata("radar_latitude"),
)

if self.radar_longitude is not None:
ds["radar_longitude"] = xarray.DataArray(
np.ma.expand_dims(self.radar_longitude["data"][0], 0),
dims=("nradar"),
attrs=get_metadata("radar_longitude"),
)

ds.close()
# write the projection dictionary as a scalar
projection = self.projection.copy()
# NetCDF does not support boolean attribute, covert to string
if "_include_lon_0_lat_0" in projection:
include = projection["_include_lon_0_lat_0"]
projection["_include_lon_0_lat_0"] = ["false", "true"][include]
ds["projection"] = xarray.DataArray(
data=np.array(1, dtype="int32"),
dims=None,
attrs=projection,
)

for attr_name in [
"origin_latitude",
"origin_longitude",
"origin_altitude",
"radar_altitude",
"radar_latitude",
"radar_longitude",
"radar_time",
]:
if hasattr(self, attr_name):
attr_data = getattr(self, attr_name)
if attr_data is not None:
if "radar_time" not in attr_name:
attr_value = np.ma.expand_dims(attr_data["data"][0], 0)
else:
attr_value = [
np.array(
num2date(
attr_data["data"][0],
units=attr_data["units"],
),
dtype="datetime64[ns]",
)
]
dims = ("nradar",)
ds[attr_name] = xarray.DataArray(
attr_value, dims=dims, attrs=get_metadata(attr_name)
)

if "radar_time" in ds.variables:
ds.radar_time.attrs.pop("calendar")

if self.radar_name is not None:
radar_name = self.radar_name["data"][0]
ds["radar_name"] = xarray.DataArray(
np.array([b"".join(radar_name)], dtype="S4"),
dims=("nradar"),
attrs=get_metadata("radar_name"),
)
ds.attrs = self.metadata
ds.close()
return ds

def add_field(self, field_name, field_dict, replace_existing=False):
Expand Down

0 comments on commit 8ad98a9

Please sign in to comment.