Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix failing tests due to auto_combine deprecation #324

Merged
merged 12 commits into from
Sep 25, 2019
18 changes: 2 additions & 16 deletions aospy/automate.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,9 @@ def _merge_dicts(*dict_args):
return result


def _input_func_py2_py3():
"""Find function for reading user input that works on Python 2 and 3.

See e.g. http://stackoverflow.com/questions/21731043
"""
try:
input = raw_input
except NameError:
import builtins
input = builtins.input
return input


def _user_verify(input_func=_input_func_py2_py3(),
prompt='Perform these computations? [y/n] '):
def _user_verify(prompt='Perform these computations? [y/n] '):
    """Prompt the user for verification.

    Parameters
    ----------
    prompt : str
        Message displayed when requesting input from the user.

    Raises
    ------
    AospyException
        If the user's response does not begin with 'y' (case-insensitive).
    """
    # Only a response whose first character is 'y'/'Y' counts as
    # confirmation; anything else cancels execution.
    if not input(prompt).lower()[0] == 'y':
        raise AospyException('Execution cancelled by user.')


Expand Down
65 changes: 18 additions & 47 deletions aospy/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import os
import pprint
import warnings

import numpy as np
import xarray as xr
Expand All @@ -11,7 +10,6 @@
ETA_STR,
GRID_ATTRS,
TIME_STR,
TIME_BOUNDS_STR,
)
from .utils import times, io

Expand Down Expand Up @@ -112,6 +110,7 @@ def set_grid_attrs_as_coords(ds):
-------
Dataset
Dataset with grid attributes set as coordinates

"""
grid_attrs_in_ds = set(GRID_ATTRS.keys()).intersection(
set(ds.coords) | set(ds.data_vars))
Expand All @@ -130,6 +129,7 @@ def _maybe_cast_to_float64(da):
Returns
-------
DataArray

"""
if da.dtype == np.float32:
logging.warning('Datapoints were stored using the np.float32 datatype.'
Expand Down Expand Up @@ -162,6 +162,7 @@ def _sel_var(ds, var, upcast_float32=True):
------
KeyError
If the variable is not in the Dataset

"""
for name in var.names:
try:
Expand All @@ -176,46 +177,6 @@ def _sel_var(ds, var, upcast_float32=True):
raise LookupError(msg)


def _prep_time_data(ds):
    """Prepare time coordinate information in Dataset for use in aospy.

    1. If the Dataset contains a time bounds coordinate, add attributes
       representing the true beginning and end dates of the time interval
       used to construct the Dataset
    2. If the Dataset contains a time bounds coordinate, overwrite the
       time coordinate values with the averages of the time bounds at
       each timestep
    3. Decode the times into np.datetime64 objects for time indexing

    Parameters
    ----------
    ds : Dataset
        Pre-processed Dataset with time coordinate renamed to
        internal_names.TIME_STR

    Returns
    -------
    Dataset
        The processed Dataset
    """
    ds = times.ensure_time_as_index(ds)
    if TIME_BOUNDS_STR not in ds:
        logging.warning("dt array not found. Assuming equally spaced "
                        "values in time, even though this may not be "
                        "the case")
        ds = times.add_uniform_time_weights(ds)
    else:
        ds = times.ensure_time_avg_has_cf_metadata(ds)
        ds[TIME_STR] = times.average_time_bounds(ds)
    # enable_cftimeindex is a no-op (and warns) on newer xarray; keep
    # setting it for backwards compatibility with older xarray versions
    # while silencing the warning.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')
        with xr.set_options(enable_cftimeindex=True):
            ds = xr.decode_cf(ds, decode_times=True, decode_coords=False,
                              mask_and_scale=True)
    return ds


def _load_data_from_disk(file_set, preprocess_func=lambda ds: ds,
data_vars='minimal', coords='minimal',
grid_attrs=None, **kwargs):
Expand Down Expand Up @@ -243,14 +204,21 @@ def _load_data_from_disk(file_set, preprocess_func=lambda ds: ds,
Returns
-------
Dataset

"""
apply_preload_user_commands(file_set)
func = _preprocess_and_rename_grid_attrs(preprocess_func, grid_attrs,
**kwargs)
return xr.open_mfdataset(file_set, preprocess=func, concat_dim=TIME_STR,
decode_times=False, decode_coords=False,
mask_and_scale=True, data_vars=data_vars,
coords=coords)
return xr.open_mfdataset(
file_set,
preprocess=func,
combine='by_coords',
spencerkclark marked this conversation as resolved.
Show resolved Hide resolved
decode_times=False,
decode_coords=False,
mask_and_scale=True,
data_vars=data_vars,
coords=coords,
)


def apply_preload_user_commands(file_set, cmd=io.dmget):
Expand All @@ -259,6 +227,7 @@ def apply_preload_user_commands(file_set, cmd=io.dmget):
For example, on the NOAA Geophysical Fluid Dynamics Laboratory
computational cluster, data that is saved on their tape archive
must be accessed via a `dmget` (or `hsmget`) command before being used.

"""
if cmd is not None:
cmd(file_set)
Expand Down Expand Up @@ -301,6 +270,7 @@ def load_variable(self, var=None, start_date=None, end_date=None,
-------
da : DataArray
DataArray for the specified variable, date range, and interval in

"""
file_set = self._generate_file_set(var=var, start_date=start_date,
end_date=end_date, **DataAttrs)
Expand All @@ -310,7 +280,7 @@ def load_variable(self, var=None, start_date=None, end_date=None,
time_offset=time_offset, grid_attrs=grid_attrs, **DataAttrs
)
if var.def_time:
ds = _prep_time_data(ds)
ds = times.prep_time_data(ds)
start_date = times.maybe_convert_to_index_date_type(
ds.indexes[TIME_STR], start_date)
end_date = times.maybe_convert_to_index_date_type(
Expand All @@ -330,6 +300,7 @@ def _load_or_get_from_model(self, var, start_date=None, end_date=None,
Supports both access of grid attributes either through the DataLoader
or through an optionally-provided Model object. Defaults to using
the version found in the DataLoader first.

"""
grid_attrs = None if model is None else model.grid_attrs

Expand Down
3 changes: 2 additions & 1 deletion aospy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def _get_grid_files(self):
try:
ds = xr.open_dataset(path, decode_times=False)
except (TypeError, AttributeError):
ds = xr.open_mfdataset(path, decode_times=False).load()
ds = xr.open_mfdataset(path, decode_times=False,
combine='by_coords').load()
except (RuntimeError, OSError) as e:
msg = str(e) + ': {}'.format(path)
raise RuntimeError(msg)
Expand Down
68 changes: 68 additions & 0 deletions aospy/test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""pytest conftest.py file for sharing fixtures across modules."""
import datetime

from cftime import DatetimeNoLeap
import numpy as np
import pytest
import xarray as xr

from aospy.internal_names import (
LON_STR,
TIME_STR,
TIME_BOUNDS_STR,
BOUNDS_STR,
)


_DATE_RANGES = {
'datetime': (datetime.datetime(2000, 1, 1),
datetime.datetime(2002, 12, 31)),
'datetime64': (np.datetime64('2000-01-01'),
np.datetime64('2002-12-31')),
'cftime': (DatetimeNoLeap(2000, 1, 1),
DatetimeNoLeap(2002, 12, 31)),
'str': ('2000', '2002')
}


@pytest.fixture()
def alt_lat_str():
    """Return a non-default latitude dimension name used by the dataset
    fixtures below."""
    return 'LATITUDE'


@pytest.fixture()
def var_name():
    """Return the name given to the data variable in the dataset
    fixtures below."""
    return 'a'


@pytest.fixture()
def ds_with_time_bounds(alt_lat_str, var_name):
    """Return a 3-timestep Dataset of zeros with a CF-style time-bounds
    variable and 'days since' time units attached."""
    time_vals = np.array([15, 46, 74])
    bounds_vals = np.array([[0, 31], [31, 59], [59, 90]])
    arr = xr.DataArray(np.zeros((3, 1, 1)),
                       coords=[time_vals, [0], [0]],
                       dims=[TIME_STR, alt_lat_str, LON_STR],
                       name=var_name)
    ds = arr.to_dataset()
    ds[TIME_BOUNDS_STR] = xr.DataArray(bounds_vals,
                                       coords=[time_vals, np.array([0, 1])],
                                       dims=[TIME_STR, BOUNDS_STR],
                                       name=TIME_BOUNDS_STR)
    # Both the time coordinate and its bounds share the same units string.
    for name in (TIME_STR, TIME_BOUNDS_STR):
        ds[name].attrs['units'] = 'days since 2000-01-01 00:00:00'
    return ds


@pytest.fixture()
def ds_inst(ds_with_time_bounds):
    """Return an instantaneous-data variant of ``ds_with_time_bounds``:
    hourly time values and no time-bounds variable."""
    inst_time = np.array([3, 6, 9])
    inst_units_str = 'hours since 2000-01-01 00:00:00'
    ds_inst = ds_with_time_bounds.copy()
    # xarray's drop returns a new Dataset rather than mutating in place;
    # the result must be re-assigned or the bounds variable remains.
    ds_inst = ds_inst.drop(TIME_BOUNDS_STR)
    ds_inst[TIME_STR].values = inst_time
    ds_inst[TIME_STR].attrs['units'] = inst_units_str
    return ds_inst
67 changes: 36 additions & 31 deletions aospy/test/test_automate.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,46 @@
from multiprocessing import cpu_count
from os.path import isfile
import shutil
import sys
import itertools
from unittest import mock

import distributed
import pytest

from aospy import Var, Proj
from aospy.automate import (_get_attr_by_tag, _permuted_dicts_of_specs,
_get_all_objs_of_type, _merge_dicts,
_input_func_py2_py3, AospyException,
_user_verify, CalcSuite, _MODELS_STR, _RUNS_STR,
_VARIABLES_STR, _REGIONS_STR,
_compute_or_skip_on_error, submit_mult_calcs,
_n_workers_for_local_cluster,
_prune_invalid_time_reductions)
from aospy.automate import (
_user_verify,
_MODELS_STR,
_RUNS_STR,
_VARIABLES_STR,
_REGIONS_STR,
_compute_or_skip_on_error,
_get_all_objs_of_type,
_get_attr_by_tag,
_merge_dicts,
_n_workers_for_local_cluster,
_permuted_dicts_of_specs,
_prune_invalid_time_reductions,
AospyException,
CalcSuite,
submit_mult_calcs,
)
from .data.objects import examples as lib
from .data.objects.examples import (
example_proj, example_model, example_run, var_not_time_defined,
condensation_rain, convection_rain, precip, ps, sphum, globe, sahel, bk,
p, dp
example_proj,
example_model,
example_run,
var_not_time_defined,
condensation_rain,
convection_rain,
precip,
ps,
sphum,
globe,
sahel,
bk,
p,
dp,
)


Expand Down Expand Up @@ -128,19 +148,12 @@ def test_merge_dicts():
assert expected == _merge_dicts(dict1, dict2, dict3, dict4)


def test_input_func_py2_py3():
    """The returned reader is the correct builtin for this interpreter."""
    func = _input_func_py2_py3()
    if sys.version.startswith('2'):
        assert func is raw_input  # noqa: F821
    elif sys.version.startswith('3'):
        import builtins
        assert func is builtins.input


def test_user_verify():
    """A response starting with 'y'/'Y' passes; anything else raises."""
    with mock.patch('builtins.input', return_value='YES'):
        _user_verify()
    with pytest.raises(AospyException):
        with mock.patch('builtins.input', return_value='no'):
            _user_verify()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -235,8 +248,6 @@ def assert_calc_files_exist(calcs, write_to_tar, dtypes_out_time):
assert not isfile(calc.path_tar_out)


@pytest.mark.skipif(sys.version.startswith('2'),
reason='https://github.com/spencerahill/aospy/issues/259')
@pytest.mark.parametrize(
('exec_options'),
[dict(parallelize=True, write_to_tar=False),
Expand All @@ -251,8 +262,6 @@ def test_submit_mult_calcs_external_client(calcsuite_init_specs_single_calc,
calcsuite_init_specs_single_calc['output_time_regional_reductions'])


@pytest.mark.skipif(sys.version.startswith('2'),
reason='https://github.com/spencerahill/aospy/issues/259')
@pytest.mark.parametrize(
('exec_options'),
[dict(parallelize=False, write_to_tar=False),
Expand All @@ -278,8 +287,6 @@ def test_submit_mult_calcs_no_calcs(calcsuite_init_specs):
submit_mult_calcs(specs)


@pytest.mark.skipif(sys.version.startswith('2'),
reason='https://github.com/spencerahill/aospy/issues/259')
@pytest.mark.parametrize(
('exec_options'),
[dict(parallelize=True, write_to_tar=False),
Expand All @@ -294,8 +301,6 @@ def test_submit_two_calcs_external_client(calcsuite_init_specs_two_calcs,
calcsuite_init_specs_two_calcs['output_time_regional_reductions'])


@pytest.mark.skipif(sys.version.startswith('2'),
reason='https://github.com/spencerahill/aospy/issues/259')
@pytest.mark.parametrize(
('exec_options'),
[dict(parallelize=False, write_to_tar=False),
Expand Down
Loading