From dbb8d8456cc84248e9135699a245917d5f41dded Mon Sep 17 00:00:00 2001 From: AS <aaron.spring@mpimet.mpg.de> Date: Thu, 21 Oct 2021 18:34:28 +0200 Subject: [PATCH] use average RPS for RPSS --- scoring/scoring_script.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/scoring/scoring_script.py b/scoring/scoring_script.py index 4fd4534e6..09e9f3fbe 100644 --- a/scoring/scoring_script.py +++ b/scoring/scoring_script.py @@ -11,7 +11,7 @@ def assert_predictions_2020(preds_test, exclude='week'): from xarray.testing import assert_equal # doesnt care about attrs but checks coords if isinstance(exclude, str): exclude = [exclude] - + # is dataset assert isinstance(preds_test, xr.Dataset) @@ -19,12 +19,12 @@ def assert_predictions_2020(preds_test, exclude='week'): if 'data_vars' not in exclude: assert 'tp' in preds_test.data_vars assert 't2m' in preds_test.data_vars - + ## coords # ignore week coord if not dim if 'week' in exclude and 'week' in preds_test.coords and 'week' not in preds_test.dims: preds_test = preds_test.drop('week') - + # forecast_time if 'forecast_time' not in exclude: d = pd.date_range(start='2020-01-02', freq='7D', periods=53) @@ -52,13 +52,13 @@ def assert_predictions_2020(preds_test, exclude='week'): lead = [pd.Timedelta(f'{i} d') for i in [14, 28]] lead_time = xr.DataArray(lead, dims='lead_time', coords={'lead_time': lead}, name='lead_time') assert_equal(lead_time, preds_test['lead_time']) - + # category if 'category' not in exclude: cat = np.array(['below normal', 'near normal', 'above normal'], dtype='<U12') category = xr.DataArray(cat, dims='category', coords={'category': cat}, name='category') assert_equal(category, preds_test['category']) - + # size if 'size' not in exclude: from dask.utils import format_bytes @@ -66,7 +66,7 @@ def assert_predictions_2020(preds_test, exclude='week'): # todo: refine for dtypes assert size_in_MB > 30 assert size_in_MB < 250 - + # no other dims if 'dims' in exclude: assert set(preds_test.dims) - {'category', 'forecast_time', 'latitude', 'lead_time', 'longitude'} == set() @@ -97,20 +97,19 @@ if __name__ == "__main__": # climatology forecast rps rps_clim = xs.rps(obs_p, clim_p, category_edges=None, dim=[], input_distributions='p').compute() - # submission rpss wrt climatology - rpss = (1 - rps_ML / rps_clim) + ## RPSS + # penalize # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7 + expect = obs_p.sum('category') + expect = expect.where(expect > 0.98).where(expect < 1.02) # should be True if not all NaN - # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7 - # penalize - penalize = obs_p.where(fct_p!=1, other=-10).mean('category') - rpss = rpss.where(penalize!=0,other=-10) + # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/50 + rps_ML = rps_ML.where(expect, other=2) # assign RPS=2 where value was expected but NaN found + # following Weigel 2007: https://doi.org/10.1175/MWR3280.1 + rpss = 1 - (rps_ML.groupby('forecast_time.year').mean() / rps_clim.groupby('forecast_time.year').mean()) # clip rpss = rpss.clip(-10, 1) - # average over all forecasts - rpss = rpss.mean('forecast_time') - # weighted area mean weights = np.cos(np.deg2rad(np.abs(rpss.latitude))) # spatially weighted score averaged over lead_times and variables to one single value -- GitLab