-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add support for monthly-resolution predictions #172
Changes from 4 commits
db55ebb
68fb5dc
881cc31
595b6d0
b420f0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,8 +87,49 @@ def compute_perfect_model(ds, control, metric='rmse', comparison='m2e'): | |
return res | ||
|
||
|
||
def _slice_to_correct_time(forecast, reference, resolution='Y', nlags=None): | ||
"""Reduces the forecast and reference object to a compatable window of time based | ||
on their minimum and maximum times and how many lags are being computed. | ||
|
||
Args: | ||
forecast (xarray object): | ||
Prediction ensemble | ||
reference (xarray object): | ||
Reference to compare predictions to | ||
resolution (str): | ||
Temporal resolution of the predictions | ||
'Y': annual | ||
'M': monthly | ||
nlags (int): | ||
Number of lags being computed for the forecast | ||
|
||
Returns: | ||
Post-processed forecast and reference, trimmed to the appropriate time lengths. | ||
""" | ||
if resolution not in ['Y', 'M']: | ||
raise ValueError( | ||
f"Your resolution of {resolution} is not 'Y' (annual) or 'M' (monthly)." | ||
) | ||
if nlags is None: | ||
nlags = forecast.lead.size | ||
# take only inits for which we have references at all leahind | ||
imin = max(forecast.time.min(), reference.time.min()) | ||
imax = min( | ||
forecast.time.max(), | ||
reference.time.max().values.astype(f'datetime64[{resolution}]') - nlags, | ||
) | ||
# Some lags force this into a numpy array for some reason. | ||
imax = xr.DataArray(imax).rename('time') | ||
forecast = forecast.where(forecast.time <= imax, drop=True) | ||
forecast = forecast.where(forecast.time >= imin, drop=True) | ||
reference = reference.where(reference.time >= imin, drop=True) | ||
return forecast, reference | ||
|
||
|
||
@check_xarray([0, 1]) | ||
def compute_hindcast(hind, reference, metric='pearson_r', comparison='e2r'): | ||
def compute_hindcast( | ||
hind, reference, metric='pearson_r', comparison='e2r', resolution='Y' | ||
): | ||
""" | ||
Compute a predictability skill score against some reference (hindcast, | ||
assimilation, reconstruction, observations). | ||
|
@@ -97,29 +138,34 @@ def compute_hindcast(hind, reference, metric='pearson_r', comparison='e2r'): | |
coefficients are for potential predictability. If the reference is | ||
observations, the output correlation coefficients are actual skill. | ||
|
||
Parameters | ||
---------- | ||
hind (xarray object): | ||
Expected to follow package conventions: | ||
`time` : dim of initialization dates | ||
`lead` : dim of lead time from those initializations | ||
Additional dims can be lat, lon, depth. | ||
reference (xarray object): | ||
reference output/data over same time period. | ||
metric (str): | ||
Metric used in comparing the decadal prediction ensemble with the | ||
reference. | ||
comparison (str): | ||
How to compare the decadal prediction ensemble to the reference. | ||
* e2r : ensemble mean to reference (Default) | ||
* m2r : each member to the reference | ||
nlags (int): How many lags to compute skill/potential predictability out | ||
to. Default: length of `lead` dim | ||
Args: | ||
hind (xarray object): | ||
Expected to follow package conventions: | ||
`time` : dim of initialization dates | ||
`lead` : dim of lead time from those initializations | ||
Additional dims can be lat, lon, depth. | ||
reference (xarray object): | ||
reference output/data over same time period. | ||
metric (str): | ||
Metric used in comparing the decadal prediction ensemble with the | ||
reference. | ||
comparison (str): | ||
How to compare the decadal prediction ensemble to the reference. | ||
* e2r : ensemble mean to reference (Default) | ||
* m2r : each member to the reference | ||
resolution (str): | ||
Temporal resolution of the hindcast prediction | ||
* 'Y': annual | ||
* 'M': monthly | ||
|
||
Returns: | ||
skill (xarray object): Predictability with main dimension `lag`. | ||
""" | ||
nlags = hind.lead.size | ||
if resolution not in ['Y', 'M']: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe when constants.py is added in my PR, add TEMPORAL_RES = ['Y', 'M'] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also add this |
||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe NotImplementedError |
||
f"Your resolution of {resolution} is not 'Y' (annual) or 'M' (monthly)." | ||
) | ||
|
||
comparison = get_comparison_function(comparison) | ||
_validate_hindcast_comparison(comparison) | ||
|
@@ -129,19 +175,16 @@ def compute_hindcast(hind, reference, metric='pearson_r', comparison='e2r'): | |
forecast, reference = comparison(hind, reference) | ||
# think in real time dimension: real time = init + lag | ||
forecast = forecast.rename({'init': 'time'}) | ||
# take only inits for which we have references at all leahind | ||
imin = max(forecast.time.min(), reference.time.min()) | ||
imax = min(forecast.time.max(), reference.time.max() - nlags) | ||
forecast = forecast.where(forecast.time <= imax, drop=True) | ||
forecast = forecast.where(forecast.time >= imin, drop=True) | ||
reference = reference.where(reference.time >= imin, drop=True) | ||
forecast, reference = _slice_to_correct_time( | ||
forecast, reference, nlags=nlags, resolution=resolution | ||
) | ||
|
||
plag = [] | ||
# iterate over all leads (accounts for lead.min() in [0,1]) | ||
for i in forecast.lead.values: | ||
# take lead year i timeseries and convert to real time | ||
a = forecast.sel(lead=i).drop('lead') | ||
a['time'] = [t + i for t in a.time.values] | ||
a['time'] = [t + i for t in a.time.values.astype(f'datetime64[{resolution}]')] | ||
# take real time reference of real time forecast years | ||
b = reference.sel(time=a.time.values) | ||
plag.append(metric(a, b, dim='time', comparison=comparison)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if you could do:
forecast.sel(time=slice(imin, imax))