Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for monthly-resolution predictions #172

Closed
wants to merge 5 commits into from
Closed
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 69 additions & 26 deletions climpred/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,49 @@ def compute_perfect_model(ds, control, metric='rmse', comparison='m2e'):
return res


def _slice_to_correct_time(forecast, reference, resolution='Y', nlags=None):
"""Reduces the forecast and reference object to a compatable window of time based
on their minimum and maximum times and how many lags are being computed.

Args:
forecast (xarray object):
Prediction ensemble
reference (xarray object):
Reference to compare predictions to
resolution (str):
Temporal resolution of the predictions
'Y': annual
'M': monthly
nlags (int):
Number of lags being computed for the forecast

Returns:
Post-processed forecast and reference, trimmed to the appropriate time lengths.
"""
if resolution not in ['Y', 'M']:
raise ValueError(
f"Your resolution of {resolution} is not 'Y' (annual) or 'M' (monthly)."
)
if nlags is None:
nlags = forecast.lead.size
# take only inits for which we have references at all leahind
imin = max(forecast.time.min(), reference.time.min())
imax = min(
forecast.time.max(),
reference.time.max().values.astype(f'datetime64[{resolution}]') - nlags,
)
# Some lags force this into a numpy array for some reason.
imax = xr.DataArray(imax).rename('time')
forecast = forecast.where(forecast.time <= imax, drop=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if you could do:

forecast.sel(time=slice(imin, imax))

forecast = forecast.where(forecast.time >= imin, drop=True)
reference = reference.where(reference.time >= imin, drop=True)
return forecast, reference


@check_xarray([0, 1])
def compute_hindcast(hind, reference, metric='pearson_r', comparison='e2r'):
def compute_hindcast(
hind, reference, metric='pearson_r', comparison='e2r', resolution='Y'
):
"""
Compute a predictability skill score against some reference (hindcast,
assimilation, reconstruction, observations).
Expand All @@ -97,29 +138,34 @@ def compute_hindcast(hind, reference, metric='pearson_r', comparison='e2r'):
coefficients are for potential predictability. If the reference is
observations, the output correlation coefficients are actual skill.

Parameters
----------
hind (xarray object):
Expected to follow package conventions:
`time` : dim of initialization dates
`lead` : dim of lead time from those initializations
Additional dims can be lat, lon, depth.
reference (xarray object):
reference output/data over same time period.
metric (str):
Metric used in comparing the decadal prediction ensemble with the
reference.
comparison (str):
How to compare the decadal prediction ensemble to the reference.
* e2r : ensemble mean to reference (Default)
* m2r : each member to the reference
nlags (int): How many lags to compute skill/potential predictability out
to. Default: length of `lead` dim
Args:
hind (xarray object):
Expected to follow package conventions:
`time` : dim of initialization dates
`lead` : dim of lead time from those initializations
Additional dims can be lat, lon, depth.
reference (xarray object):
reference output/data over same time period.
metric (str):
Metric used in comparing the decadal prediction ensemble with the
reference.
comparison (str):
How to compare the decadal prediction ensemble to the reference.
* e2r : ensemble mean to reference (Default)
* m2r : each member to the reference
resolution (str):
Temporal resolution of the hindcast prediction
* 'Y': annual
* 'M': monthly

Returns:
skill (xarray object): Predictability with main dimension `lag`.
"""
nlags = hind.lead.size
if resolution not in ['Y', 'M']:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe when constants.py is added in my PR, add TEMPORAL_RES = ['Y', 'M']

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add this if ... to checks.py since I see it used in multiple places

raise ValueError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe NotImplementedError

f"Your resolution of {resolution} is not 'Y' (annual) or 'M' (monthly)."
)

comparison = get_comparison_function(comparison)
_validate_hindcast_comparison(comparison)
Expand All @@ -129,19 +175,16 @@ def compute_hindcast(hind, reference, metric='pearson_r', comparison='e2r'):
forecast, reference = comparison(hind, reference)
# think in real time dimension: real time = init + lag
forecast = forecast.rename({'init': 'time'})
# take only inits for which we have references at all leahind
imin = max(forecast.time.min(), reference.time.min())
imax = min(forecast.time.max(), reference.time.max() - nlags)
forecast = forecast.where(forecast.time <= imax, drop=True)
forecast = forecast.where(forecast.time >= imin, drop=True)
reference = reference.where(reference.time >= imin, drop=True)
forecast, reference = _slice_to_correct_time(
forecast, reference, nlags=nlags, resolution=resolution
)

plag = []
# iterate over all leads (accounts for lead.min() in [0,1])
for i in forecast.lead.values:
# take lead year i timeseries and convert to real time
a = forecast.sel(lead=i).drop('lead')
a['time'] = [t + i for t in a.time.values]
a['time'] = [t + i for t in a.time.values.astype(f'datetime64[{resolution}]')]
# take real time reference of real time forecast years
b = reference.sel(time=a.time.values)
plag.append(metric(a, b, dim='time', comparison=comparison))
Expand Down