Skip to content
Snippets Groups Projects
Commit 6d3a9070 authored by Aaron Spring's avatar Aaron Spring :baby_symbol:
Browse files

Merge branch 'AS_RPS_mean_4_RPSS' into 'master'

use average RPS for RPSS

Closes aaron.spring/s2s-ai-challenge#50

See merge request !5
parents b57760f3 23ec4793
No related branches found
No related tags found
1 merge request!5use average RPS for RPSS
Pipeline #271464 passed with stage
in 4 minutes and 36 seconds
......@@ -11,7 +11,7 @@ def assert_predictions_2020(preds_test, exclude='week'):
from xarray.testing import assert_equal # doesnt care about attrs but checks coords
if isinstance(exclude, str):
exclude = [exclude]
# is dataset
assert isinstance(preds_test, xr.Dataset)
......@@ -19,12 +19,12 @@ def assert_predictions_2020(preds_test, exclude='week'):
if 'data_vars' not in exclude:
assert 'tp' in preds_test.data_vars
assert 't2m' in preds_test.data_vars
## coords
# ignore week coord if not dim
if 'week' in exclude and 'week' in preds_test.coords and 'week' not in preds_test.dims:
preds_test = preds_test.drop('week')
# forecast_time
if 'forecast_time' not in exclude:
d = pd.date_range(start='2020-01-02', freq='7D', periods=53)
......@@ -52,13 +52,13 @@ def assert_predictions_2020(preds_test, exclude='week'):
lead = [pd.Timedelta(f'{i} d') for i in [14, 28]]
lead_time = xr.DataArray(lead, dims='lead_time', coords={'lead_time': lead}, name='lead_time')
assert_equal(lead_time, preds_test['lead_time'])
# category
if 'category' not in exclude:
cat = np.array(['below normal', 'near normal', 'above normal'], dtype='<U12')
category = xr.DataArray(cat, dims='category', coords={'category': cat}, name='category')
assert_equal(category, preds_test['category'])
# size
if 'size' not in exclude:
from dask.utils import format_bytes
......@@ -66,7 +66,7 @@ def assert_predictions_2020(preds_test, exclude='week'):
# todo: refine for dtypes
assert size_in_MB > 30
assert size_in_MB < 250
# no other dims
if 'dims' in exclude:
assert set(preds_test.dims) - {'category', 'forecast_time', 'latitude', 'lead_time', 'longitude'} == set()
......@@ -97,20 +97,19 @@ if __name__ == "__main__":
# climatology forecast rps
rps_clim = xs.rps(obs_p, clim_p, category_edges=None, dim=[], input_distributions='p').compute()
# submission rpss wrt climatology
rpss = (1 - rps_ML / rps_clim)
## RPSS
# penalize # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7
expect = obs_p.sum('category')
expect = expect.where(expect > 0.98).where(expect < 1.02) # should be True if not all NaN
# https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7
# penalize
penalize = obs_p.where(fct_p!=1, other=-10).mean('category')
rpss = rpss.where(penalize!=0,other=-10)
# https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/50
rps_ML = rps_ML.where(expect, other=2) # assign RPS=2 where value was expected but NaN found
# following Weigel 2007: https://doi.org/10.1175/MWR3280.1
rpss = 1 - (rps_ML.mean('forecast_time') / rps_clim.mean('forecast_time'))
# clip
rpss = rpss.clip(-10, 1)
# average over all forecasts
rpss = rpss.mean('forecast_time')
# weighted area mean
weights = np.cos(np.deg2rad(np.abs(rpss.latitude)))
# spatially weighted score averaged over lead_times and variables to one single value
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment