# NOTE(review): "Newer" / "Older" were web-page navigation residue from the
# source this script was copied from; commented out so the file parses.
import argparse
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
import xskillscore as xs
def assert_predictions_2020(preds_test, exclude='week'):
    """Check the variables, coordinates and dimensions of 2020 predictions.

    Parameters
    ----------
    preds_test : xr.Dataset
        Predictions to validate against the 2020 test-period layout.
    exclude : str or list of str, default 'week'
        Names of individual checks to skip. Recognized entries:
        'data_vars', 'week', 'forecast_time', 'longitude', 'latitude',
        'lead_time', 'category', 'size', 'dims'.

    Raises
    ------
    AssertionError
        If any enabled check fails.
    """
    from xarray.testing import assert_equal  # doesnt care about attrs but checks coords
    if isinstance(exclude, str):
        exclude = [exclude]
    # is dataset
    assert isinstance(preds_test, xr.Dataset)
    # has both vars: tp and t2m
    if 'data_vars' not in exclude:
        assert 'tp' in preds_test.data_vars
        assert 't2m' in preds_test.data_vars
    # ignore week coord if not dim
    if 'week' in exclude and 'week' in preds_test.coords and 'week' not in preds_test.dims:
        # drop_vars: non-deprecated spelling of Dataset.drop for coords/vars
        preds_test = preds_test.drop_vars('week')
    # weekly init dates covering 2020: Thursdays starting 2020-01-02
    if 'forecast_time' not in exclude:
        d = pd.date_range(start='2020-01-02', freq='7D', periods=53)
        forecast_time = xr.DataArray(d, dims='forecast_time', coords={'forecast_time': d}, name='forecast_time')
        assert_equal(forecast_time, preds_test['forecast_time'])
    # longitude: 1.5 deg global grid
    if 'longitude' not in exclude:
        lon = np.arange(0., 360., 1.5)
        longitude = xr.DataArray(lon, dims='longitude', coords={'longitude': lon}, name='longitude')
        assert_equal(longitude, preds_test['longitude'])
    # latitude
    if 'latitude' not in exclude:
        #lat = np.arange(-90., 90.1, 1.5)[::-1]
        #latitude = xr.DataArray(lat, dims='latitude', coords={'latitude': lat}, name='latitude')
        #assert_equal(latitude, preds_test['latitude'])
        # above too strict, allow for not submitting south of 60S:
        # only require a descending 1.5 deg grid that spans at least 60N..60S
        assert (preds_test.latitude.diff('latitude') == -1.5).all()
        assert 60 in preds_test.latitude
        assert -60 in preds_test.latitude
    # biweekly lead times: weeks 3-4 and 5-6
    if 'lead_time' not in exclude:
        lead = [pd.Timedelta(f'{i} d') for i in [14, 28]]
        lead_time = xr.DataArray(lead, dims='lead_time', coords={'lead_time': lead}, name='lead_time')
        assert_equal(lead_time, preds_test['lead_time'])
    # tercile categories
    if 'category' not in exclude:
        cat = np.array(['below normal', 'near normal', 'above normal'], dtype='<U12')
        category = xr.DataArray(cat, dims='category', coords={'category': cat}, name='category')
        assert_equal(category, preds_test['category'])
    # rough sanity check that the submission is not (near-)empty
    if 'size' not in exclude:
        # BUGFIX: previously parsed dask.utils.format_bytes() output, which
        # returns the leading number in whatever unit fits (kiB/MiB/GiB/...),
        # so e.g. "1.2 GiB" was read as 1.2 MB and failed the check.
        # Compute MiB from nbytes directly instead (2**20 bytes per MiB,
        # matching format_bytes' binary units).
        size_in_MB = preds_test.nbytes / 2**20
        # todo: refine for dtypes
        assert size_in_MB > 30
    # no other dims
    if 'dims' not in exclude:  # BUGFIX: was `'dims' in exclude`, which inverted the opt-out
        assert set(preds_test.dims) - {'category', 'forecast_time', 'latitude', 'lead_time', 'longitude'} == set()
if __name__ == "__main__":
    # CLI: score a prediction netcdf against the 2020 terciled observations.
    parser = argparse.ArgumentParser()
    parser.add_argument("prediction", help="The netcdf file with predictions")
    args = parser.parse_args()

    cache_path = "scoring"
    observations_terciled_fin = Path(f'{cache_path}/forecast-like-observations_2020_biweekly_terciled.nc')

    # BUGFIX: obs_p and fct_p were referenced below but never defined
    # (NameError at runtime) — load them from the expected files.
    # NOTE(review): assumes both netcdf files exist and open with the default
    # engine — confirm against the challenge's scoring environment.
    obs_p = xr.open_dataset(observations_terciled_fin)
    fct_p = xr.open_dataset(args.prediction)

    # check inputs
    assert_predictions_2020(obs_p)
    assert_predictions_2020(fct_p)

    # climatology forecast: equal 1/3 probability for each tercile category
    clim_p = xr.DataArray(
        [1 / 3, 1 / 3, 1 / 3],
        dims='category',
        coords={'category': ['below normal', 'near normal', 'above normal']},
    ).to_dataset(name='tp')
    clim_p['t2m'] = clim_p['tp']

    # submission rps (probabilistic inputs, no aggregation over dims here)
    rps_ML = xs.rps(obs_p, fct_p, category_edges=None, dim=[], input_distributions='p').compute()
    # climatology forecast rps
    rps_clim = xs.rps(obs_p, clim_p, category_edges=None, dim=[], input_distributions='p').compute()

    ## RPSS
    # penalize # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7
    # expect: category probabilities of valid observations sum to ~1; NaN elsewhere
    expect = obs_p.sum('category')
    expect = expect.where(expect > 0.98).where(expect < 1.02)  # should be True if not all NaN
    # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/50
    rps_ML = rps_ML.where(expect, other=2)  # assign RPS=2 where value was expected but NaN found
    # following Weigel 2007: https://doi.org/10.1175/MWR3280.1
    rpss = 1 - (rps_ML.mean('forecast_time') / rps_clim.mean('forecast_time'))

    # area weighting by cos(latitude)
    weights = np.cos(np.deg2rad(np.abs(rpss.latitude)))
    # spatially weighted score averaged over lead_times and variables to one single value
    scores = rpss.sel(latitude=slice(None, -60)).weighted(weights).mean('latitude').mean('longitude')
    scores = scores.to_array().mean(['lead_time', 'variable']).reset_coords(drop=True)

    # BUGFIX: the final score was computed but never reported — print it.
    print(float(scores))