Skip to content

Commit

Permalink
Merge pull request #1461 from knutfrode/dev
Browse files Browse the repository at this point in the history
[run-ex] Dynamic detection of dimension order in generic reader also …
  • Loading branch information
knutfrode authored Dec 18, 2024
2 parents d66264c + f2e95d8 commit a29bed0
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 62 deletions.
11 changes: 10 additions & 1 deletion opendrift/readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,16 @@ def open_dataset_opendrift(source, zarr_storage_options=None, open_mfdataset_opt

return ds


def datetime_from_variable(var):
    """Convert a time variable to an array of ``datetime.datetime`` objects.

    Conversion is first attempted directly with ``pandas.to_datetime``.
    If that raises, the variable is converted through
    ``var.to_index().to_datetimeindex()`` instead, and the result of that
    is decoded with Pandas.

    Args:
        var: Time variable to decode. The fallback path requires ``var`` to
            provide ``to_index()`` whose result has ``to_datetimeindex()``
            (presumably an xarray variable with cftime-based times --
            TODO confirm against callers).

    Returns:
        The output of ``DatetimeIndex.to_pydatetime()``, i.e. the decoded
        times as Python datetime objects.

    Raises:
        Exception: Whatever the fallback conversion raises if it fails too.
    """
    # Kept as a function-scope import, as in the original implementation.
    import pandas as pd

    try:
        return pd.to_datetime(var).to_pydatetime()
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt and SystemExit are never swallowed here.
        logger.warning('Could not decode time with Pandas')

    datetimeindex = var.to_index().to_datetimeindex()
    times = pd.to_datetime(datetimeindex).to_pydatetime()
    logger.info('Decoded time through datetimeindex')
    return times

def open_mfdataset_overlap(url_base, time_series=None, start_time=None, end_time=None, freq=None, timedim='time'):
if time_series is None:
Expand Down
10 changes: 8 additions & 2 deletions opendrift/readers/basereader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,14 @@ def plot(self, variable=None, vmin=None, vmax=None, time=None,
if time is None:
time = self.start_time
if variable is not None:
rx = np.array([self.xmin, self.xmax])
ry = np.array([self.ymin, self.ymax])
if self.global_coverage():
ax.set_global()
# Spaced by 500 to avoid splitting in west/east blocks
rx = np.linspace(self.xmin, self.xmax, num=int(np.ceil(self.numx/500))+1)
ry = np.linspace(self.ymin, self.ymax, num=len(rx))
else:
rx = np.array([self.xmin, self.xmax])
ry = np.array([self.ymin, self.ymax])
if variable in self.derived_variables:
data = self.get_variables(self.derived_variables[variable], time, rx, ry)
self.__calculate_derived_environment_variables__(data)
Expand Down
6 changes: 4 additions & 2 deletions opendrift/readers/basereader/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def _get_variables_interpolated_(self, variables, profiles, profiles_depth,
)
self.var_block_before[blockvars_before] = \
ReaderBlock(reader_data_dict,
interpolation_horizontal=self.interpolation)
interpolation_horizontal=self.interpolation,
wrap_x=self.global_coverage())
try:
len_z = len(self.var_block_before[blockvars_before].z)
except:
Expand All @@ -304,7 +305,8 @@ def _get_variables_interpolated_(self, variables, profiles, profiles_depth,
self.var_block_after[blockvars_after] = \
ReaderBlock(
reader_data_dict,
interpolation_horizontal=self.interpolation)
interpolation_horizontal=self.interpolation,
wrap_x=self.global_coverage())
try:
len_z = len(self.var_block_after[blockvars_after].z)
except:
Expand Down
23 changes: 22 additions & 1 deletion opendrift/readers/interpolation/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class ReaderBlock():

def __init__(self, data_dict,
interpolation_horizontal='linearNDFast',
interpolation_vertical='linear'):
interpolation_vertical='linear',
wrap_x=False):

# Make pointers to data values, for convenience
self.x = data_dict['x']
Expand All @@ -29,6 +30,14 @@ def __init__(self, data_dict,
del self.data_dict['z']
except:
self.z = None
self.wrap_x = wrap_x # For global readers where longitude wraps at 360
if wrap_x is True:
if self.x.min() < 180:
logger.debug('Shifting reader block longitudes to -180 to 180')
self.x = np.mod(self.x+180, 360) - 180
elif self.x.max() > 360:
logger.debug('Shifting reader block longitudes to 0 to 360')
self.x = np.mod(self.x, 360)

# Mask any extremely large values, e.g. if missing netCDF _Fill_value
filled_variables = set()
Expand Down Expand Up @@ -78,6 +87,11 @@ def __init__(self, data_dict,

def _initialize_interpolator(self, x, y, z=None):
logger.debug('Initialising interpolator.')
if self.wrap_x is True:
if self.x.min() > 0:
x = np.mod(x, 360) # Shift x/lons to 0-360
else:
x = np.mod(x + 180, 360) - 180 # Shift x/lons to -180-180
self.interpolator2d = self.Interpolator2DClass(self.x, self.y, x, y)
if self.z is not None and len(np.atleast_1d(self.z)) > 1:
self.interpolator1d = self.Interpolator1DClass(self.z, z)
Expand Down Expand Up @@ -143,6 +157,13 @@ def _interpolate_horizontal_layers(self, data, nearest=False):
def covers_positions(self, x, y, z=None):
'''Check if given positions are covered by this reader block.'''

if self.wrap_x is True:
if self.x.min() < 180:
logger.debug('Shifting longitudes to -180 to 180')
x = np.mod(x+180, 360) - 180
elif self.x.max() > 360:
logger.debug('Shifting longitudes to 0 to 360')
x = np.mod(x, 360)
indices = np.where((x >= self.x.min()) & (x <= self.x.max()) &
(y >= self.y.min()) & (y <= self.y.max()))[0]
if len(indices) == len(x):
Expand Down
140 changes: 84 additions & 56 deletions opendrift/readers/reader_netCDF_CF_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pandas as pd
import xarray as xr
from opendrift.readers.basereader import BaseReader, StructuredReader
from opendrift.readers import open_dataset_opendrift
from opendrift.readers import open_dataset_opendrift, datetime_from_variable


class Reader(StructuredReader, BaseReader):
Expand Down Expand Up @@ -228,7 +228,8 @@ def __init__(self, filename=None, zarr_storage_options=None, name=None, proj4=No
# self.times = [pd.to_datetime(str(d)) for d in time]
# else:
# self.times = num2date(time, time_units, calendar=calendar)
self.times = pd.to_datetime(var).to_pydatetime()
#self.times = pd.to_datetime(var).to_pydatetime()
self.times = datetime_from_variable(var)
self.start_time = self.times[0]
self.end_time = self.times[-1]
if len(self.times) > 1:
Expand Down Expand Up @@ -374,11 +375,11 @@ def __init__(self, filename=None, zarr_storage_options=None, name=None, proj4=No

self.variables = list(self.variable_mapping.keys())

# Workaround for datasets with unnecessary ensemble dimension for static variables
for vn, va in self.variable_mapping.items():
if vn == 'sea_floor_depth_below_sea_level':
var = self.Dataset.variables[va]
if 'ensemble_member' in var.dims:
# Workaround for datasets with unnecessary ensemble dimension for static variables
logger.info(f'Removing ensemble dimension from {vn}')
var = var.isel(ensemble_member=0).squeeze()
self.Dataset[va] = var
Expand Down Expand Up @@ -420,30 +421,33 @@ def get_variables(self, requested_variables, time=None,
clipped = self.clipped
else: clipped = 0

if self.global_coverage():
buffer = self.buffer # Adding buffer, to cover also future positions of elements
indy = np.floor(np.abs(y-self.y[0])/self.delta_y-clipped).astype(int) + clipped
indy = np.arange(np.max([0, indy.min()-buffer]),
np.min([indy.max()+buffer, self.numy]))

if self.global_coverage(): # Treatment of cyclic longitudes (x-coordinate)
if self.lon_range() == '0to360':
x = np.mod(x, 360) # Shift x/lons to 0-360
elif self.lon_range() == '-180to180':
x = np.mod(x + 180, 360) - 180 # Shift x/lons to -180-180
indx = np.floor(np.abs(x-self.x[0])/self.delta_x-clipped).astype(int) + clipped
indy = np.floor(np.abs(y-self.y[0])/self.delta_y-clipped).astype(int) + clipped
buffer = self.buffer # Adding buffer, to cover also future positions of elements
indy = np.arange(np.max([0, indy.min()-buffer]),
np.min([indy.max()+buffer, self.numy]))
indx = np.arange(indx.min()-buffer, indx.max()+buffer+1)

if self.global_coverage() and indx.min() < 0 and indx.max() > 0 and indx.max() < self.numx:
logger.debug('Requested data block is not continuous in file'+
', must read two blocks and concatenate.')
indx_left = indx[indx<0] + self.numx # Shift to positive indices
indx_right = indx[indx>=0]
if indx_right.max() >= indx_left.min(): # Avoid overlap
indx_right = np.arange(indx_right.min(), indx_left.min())
continuous = False
else:
continuous = True
indx = np.arange(np.max([0, indx.min()]),
np.min([indx.max(), self.numx]))

split = False
if self.global_coverage(): # Check if need to split in two blocks
uniqx = np.unique(indx)
diff_xind = np.diff(uniqx)
# We split if >800 pixels between left/west and right/east blocks
if len(diff_xind)>1 and diff_xind.max() > np.minimum(800, 0.6*self.numx):
logger.debug('Requested data block crosses lon-border, reading and concatinating two parts')
split = True
splitind = np.argmax(diff_xind)
indx_left = np.arange(0, uniqx[splitind] + buffer)
indx_right = np.arange(uniqx[splitind+1] - buffer, self.numx)
indx = np.concatenate((indx_right, indx_left))
if split is False:
indx = np.arange(np.maximum(0, indx.min()-buffer),
np.minimum(indx.max()+buffer+1, self.numx))

variables = {}

Expand All @@ -458,9 +462,19 @@ def get_variables(self, requested_variables, time=None,
var = self.Dataset.variables[self.variable_mapping[par]]

ensemble_dim = None
if continuous is True:
dimindices = {'x': indx, 'y': indy, 'time': indxTime, 'z': indz}
dimorder = list(var.dims)
xnum = dimorder.index(self.dimensions['x'])
ynum = dimorder.index(self.dimensions['y'])
if xnum < ynum:
# We must have y before x, since returning numpy arrays and not Xarrays
logger.debug(f'Swapping order of x-y dimensions for {par}')
dimorder[xnum] = self.dimensions['y']
dimorder[ynum] = self.dimensions['x']
var = var.permute_dims(*dimorder)

if split is False:
if True: # new dynamic way
dimindices = {'x': indx, 'y': indy, 'time': indxTime, 'z': indz}
subset = {vdim:dimindices[dim] for dim,vdim in self.dimensions.items() if vdim in var.dims}
variables[par] = var.isel(subset)
# Remove any unknown dimensions
Expand All @@ -470,39 +484,53 @@ def get_variables(self, requested_variables, time=None,
variables[par] = variables[par].squeeze(dim=dim)
if self.ensemble_dimension is not None and self.ensemble_dimension in variables[par].dims:
ensemble_dim = 0 # hardcoded, may not work for MEPS
else: # old hardcoded way
if var.ndim == 2:
variables[par] = var[indy, indx]
elif var.ndim == 3:
variables[par] = var[indxTime, indy, indx]
elif var.ndim == 4:
variables[par] = var[indxTime, indz, indy, indx]
elif var.ndim == 5: # Ensemble data
variables[par] = var[indxTime, indz, indrealization, indy, indx]
ensemble_dim = 0 # Hardcoded ensemble dimension for now
else:
raise Exception('Wrong dimension of variable: ' +
self.variable_mapping[par])
#else: # old hardcoded way, to be removed
# if var.ndim == 2:
# variables[par] = var[indy, indx]
# elif var.ndim == 3:
# variables[par] = var[indxTime, indy, indx]
# elif var.ndim == 4:
# variables[par] = var[indxTime, indz, indy, indx]
# elif var.ndim == 5: # Ensemble data
# variables[par] = var[indxTime, indz, indrealization, indy, indx]
# ensemble_dim = 0 # Hardcoded ensemble dimension for now
# else:
# raise Exception('Wrong dimension of variable: ' +
# self.variable_mapping[par])
# The below should also be updated to dynamic subsetting
else: # We need to read left and right parts separately
if var.ndim == 2:
left = var[indy, indx_left]
right = var[indy, indx_right]
variables[par] = xr.Variable.concat([left, right], dim=self.dimensions['x'])
elif var.ndim == 3:
left = var[indxTime, indy, indx_left]
right = var[indxTime, indy, indx_right]
variables[par] = xr.Variable.concat([left, right], dim=self.dimensions['x'])
elif var.ndim == 4:
left = var[indxTime, indz, indy, indx_left]
right = var[indxTime, indz, indy, indx_right]
variables[par] = xr.Variable.concat([left, right], dim=self.dimensions['x'])
elif var.ndim == 5: # Ensemble data
left = var[indxTime, indz, indrealization,
indy, indx_left]
right = var[indxTime, indz, indrealization,
indy, indx_right]
variables[par] = xr.Variable.concat([left, right], dim=self.dimensions['x'])
d_left = dimindices.copy()
d_right = dimindices.copy()
d_left.update({'x': indx_left})
d_right.update({'x': indx_right})
subset_left = {vdim:d_left[dim] for dim,vdim in self.dimensions.items()
if vdim in var.dims}
subset_right = {vdim:d_right[dim] for dim,vdim in self.dimensions.items()
if vdim in var.dims}
left = var.isel(subset_left)
right = var.isel(subset_right)

#if var.ndim == 2:
# left = var[indy, indx_left]
# right = var[indy, indx_right]
# variables[par] = xr.Variable.concat([left, right], dim=self.dimensions['x'])
#elif var.ndim == 3:
# left = var[indxTime, indy, indx_left]
# right = var[indxTime, indy, indx_right]
# variables[par] = xr.Variable.concat([left, right], dim=self.dimensions['x'])
#elif var.ndim == 4:
# left = var[indxTime, indz, indy, indx_left]
# right = var[indxTime, indz, indy, indx_right]
# variables[par] = xr.Variable.concat([left, right], dim=self.dimensions['x'])
#elif var.ndim == 5: # Ensemble data
# left = var[indxTime, indz, indrealization,
# indy, indx_left]
# right = var[indxTime, indz, indrealization,
# indy, indx_right]
# variables[par] = xr.Variable.concat([left, right], dim=self.dimensions['x'])

#variables[par] = xr.Variable.concat([left, right], dim=self.dimensions['x'])
variables[par] = xr.Variable.concat([right, left], dim=self.dimensions['x'])
variables[par] = np.ma.masked_invalid(variables[par])

# Mask values outside domain
Expand Down Expand Up @@ -532,7 +560,7 @@ def get_variables(self, requested_variables, time=None,
if self.projected is True:
variables['x'] = self.x[indx]
variables['y'] = self.y[indy]
if continuous is False and variables['x'][0] > variables['x'][-1]:
if split is True and variables['x'][0] > variables['x'][-1]:
# We need to shift so that x-coordinate (longitude) is continuous
if self.lon_range() == '-180to180':
variables['x'][variables['x']>0] -= 360
Expand Down

0 comments on commit a29bed0

Please sign in to comment.